diff --git a/examples/multimodal/config.py b/examples/multimodal/config.py index d242651114..4524df4480 100644 --- a/examples/multimodal/config.py +++ b/examples/multimodal/config.py @@ -60,6 +60,21 @@ def get_language_model_config(config): config.apply_rope_fusion = False config.attention_softmax_in_fp32 = True config.ffn_hidden_size = 14336 + elif config.language_model_type == "yi-34b": + config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 20480 + else: + raise ValueError(f"unknown language model type {config.language_model_type}") return config @@ -107,6 +122,30 @@ def get_vision_model_config(config, apply_query_key_layer_scaling): config.apply_rope_fusion = False config.qk_layernorm = False config.layernorm_epsilon = 1e-6 + elif config.vision_model_type == "internvit": + config.num_layers = 45 + config.num_attention_heads = 32 # Padded for TP=8. + config.num_query_groups = 32 # Padded for TP=8. + config.kv_channels = 128 + config.add_bias_linear = True + config.add_qkv_bias = False + config.hidden_size = 3200 + config.hidden_dropout = 0.0 + config.attention_dropout = 0.0 + config.ffn_hidden_size = 12800 + config.gated_linear_unit = False + config.activation_func = torch.nn.functional.gelu + config.layernorm_zero_centered_gamma = False + config.apply_query_key_layer_scaling = apply_query_key_layer_scaling + config.bias_activation_fusion = False + config.bias_dropout_fusion = False + config.attention_softmax_in_fp32 = True + config.normalization = 'RMSNorm' + config.layernorm_epsilon = 1e-6 + config.apply_rope_fusion = False + else: + raise ValueError(f"unknown vision model type {config.vision_model_type}") + return config @@ -128,6 +167,12 @@ def get_vision_projection_config(config, hidden_size): elif config.language_model_type == "mistral_7b": config.ffn_hidden_size = 14336 config.activation_func = torch.nn.functional.gelu + elif config.language_model_type == "yi-34b": + config.ffn_hidden_size = 20480 + config.normalization = 'LayerNorm' + config.activation_func = torch.nn.functional.gelu + else: + raise ValueError(f"unknown language model type {config.language_model_type}") return config diff --git a/examples/multimodal/image_processing.py b/examples/multimodal/image_processing.py index a4541576ae..7e0dcdfe74 100644 --- a/examples/multimodal/image_processing.py +++ b/examples/multimodal/image_processing.py @@ -7,13 +7,18 @@ from torchvision.transforms import Compose, RandAugment, RandomResizedCrop, Resize, ToPILImage -# Imagenet's mean and std. -pixel_mean = [123.675, 116.28, 103.53] -pixel_std = [58.395, 57.12, 57.375] - # Reshape for broadcasting. 
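+# ImageNet mean/std (used by CLIP and InternViT); SigLIP normalizes with 127.5.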
-pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1) -pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1) +pixel_mean_clip = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1) +pixel_std_clip = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1) + +pixel_mean_siglip = torch.Tensor([127.5, 127.5, 127.5]).view(-1, 1, 1) +pixel_std_siglip = torch.Tensor([127.5, 127.5, 127.5]).view(-1, 1, 1) + +pixel_statistics = { + "clip": (pixel_mean_clip, pixel_std_clip), + "siglip": (pixel_mean_siglip, pixel_std_siglip), + "internvit": (pixel_mean_clip, pixel_std_clip), +} def convert_to_rgb(image): @@ -36,12 +41,14 @@ def _transform_test(img_h, img_w): ]) -def standardize_image(img): +def standardize_image(img, mean, std): """Standardize image pixel values.""" - return (torch.Tensor(np.array(img)).permute(2, 0, 1) - pixel_mean) / pixel_std + return (torch.Tensor(np.array(img)).permute(2, 0, 1) - mean) / std + +def get_visual_transform(img, img_h, img_w, use_tiling=False, max_num_tiles=1, use_thumbnail=False, augment=False, vision_model_type="clip"): + pixel_mean, pixel_std = pixel_statistics[vision_model_type] -def get_visual_transform(img, img_h, img_w, use_tiling=False, max_num_tiles=1, use_thumbnail=False, augment=False): if use_tiling: assert img_h == img_w, "dynamic tiling expects equal tile height and width" imgs = dynamic_preprocess(img, min_num=1, max_num=max_num_tiles, image_size=img_h, use_thumbnail=use_thumbnail) @@ -60,7 +67,7 @@ def get_visual_transform(img, img_h, img_w, use_tiling=False, max_num_tiles=1, u img = visual_transform(img) # Standardize pixel values. - img = standardize_image(img) + img = standardize_image(img, pixel_mean, pixel_std) # Pad to target image size. delta_h, delta_w = img_h - scaled_h, img_w - scaled_w diff --git a/examples/multimodal/model.py b/examples/multimodal/model.py index ab700a19f5..28bb6bcb84 100644 --- a/examples/multimodal/model.py +++ b/examples/multimodal/model.py @@ -37,7 +37,7 @@ def model_provider( num_image_embeddings = get_num_image_embeddings( args.img_h, args.img_w, args.patch_dim, args.vision_model_type, - args.disable_vision_class_token, 1 + args.disable_vision_class_token, 1, args.pixel_shuffle, ) old_seq_length = args.seq_length args.seq_length = args.encoder_seq_length = num_image_embeddings @@ -92,6 +92,9 @@ def model_provider( vision_transformer_layer_spec = get_layer_spec( is_vit=True, normalization=vision_config.normalization ) + elif vision_model_type == "internvit": + from nvlm.internvit import get_internvit_layer_spec + vision_transformer_layer_spec = get_internvit_layer_spec(use_te=use_te) else: raise RuntimeError("unsupported vision model type", vision_model_type) diff --git a/examples/multimodal/model_converter/internvit_converter.py b/examples/multimodal/model_converter/internvit_converter.py new file mode 100644 index 0000000000..48404c2084 --- /dev/null +++ b/examples/multimodal/model_converter/internvit_converter.py @@ -0,0 +1,162 @@ +import argparse +import os + +import torch +from transformers import AutoModel + + +def convert(model_name, output_path, tensor_parallel_size, use_te): + """Convert InternViT HF checkpoint to mcore.""" + hf_model = AutoModel.from_pretrained( + model_name, + trust_remote_code=True + ) + + hf_state_dict = hf_model.state_dict() + new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)] + + hidden_size = 3200 + num_heads = 25 + dim = 128 + + order = torch.ones(3 * hidden_size).long() + + for j in range(num_heads): + for i in range(dim): + order[i + dim*3*j] = j*dim+i + order[dim 
+ i + dim*3*j] = j*dim+i+num_heads*dim + order[dim*2 + i + dim*3*j] = j*dim+i+num_heads*dim*2 + + for name, tensor in hf_state_dict.items(): + # Map parameter names to ones used in megatron. + new_name = "" + new_tensor = tensor + + # This is used for chunking some tensors to target tensor parallel size. + chunk_dim = None + + if "embeddings.class_embedding" in name: + new_name = "class_token" + elif "embeddings.patch_embedding.weight" in name: + new_name = "conv1.weight" + elif "embeddings.patch_embedding.bias" in name: + new_name = "conv1.bias" + elif "embeddings.position_embedding" in name: + new_name = "position_embeddings.weight" + new_tensor = new_tensor.squeeze(0) + elif "encoder.layers" in name: + layer_idx = name.split(".")[2] + + base = f"decoder.layers.{layer_idx}" + + head_dim = 128 + + if tensor_parallel_size == 1: + num_padded_heads = 25 + elif tensor_parallel_size == 8: + # Note: 25 is not divisible by 8 and we don't currently support uneven heads split with tensor parallelism. + # So we pad with dummy all-zero heads. Please use a nice even number of attention heads in your model. + num_padded_heads = 32 + else: + raise NotImplementedError("invalid tensor parallel size value:", tensor_parallel_size) + + if "ls1" in name: + new_name = f"{base}.ls1" + elif "ls2" in name: + new_name = f"{base}.ls2" + elif "attn.qkv.weight" in name: + new_name = f"{base}.self_attention.linear_qkv.weight" + num_tensors = 3 + padded_dim = head_dim * num_padded_heads * num_tensors + padded_tensor = torch.zeros((padded_dim, new_tensor.shape[-1]), dtype=new_tensor.dtype, device=new_tensor.device) + padded_tensor[:new_tensor.shape[0], :] = new_tensor[order] + new_tensor = padded_tensor + chunk_dim = 0 + elif "attn.q_norm.weight" in name: + new_name = f"{base}.self_attention.q_layernorm.weight" + num_tensors = 1 + padded_dim = head_dim * num_padded_heads * num_tensors + padded_tensor = torch.zeros(padded_dim, dtype=new_tensor.dtype, device=new_tensor.device) + padded_tensor[:new_tensor.shape[0]] = new_tensor + new_tensor = padded_tensor + chunk_dim = 0 + elif "attn.k_norm.weight" in name: + new_name = f"{base}.self_attention.k_layernorm.weight" + num_tensors = 1 + padded_dim = head_dim * num_padded_heads * num_tensors + padded_tensor = torch.zeros(padded_dim, dtype=new_tensor.dtype, device=new_tensor.device) + padded_tensor[:new_tensor.shape[0]] = new_tensor + new_tensor = padded_tensor + chunk_dim = 0 + elif "attn.proj.weight" in name: + new_name = f"{base}.self_attention.linear_proj.weight" + num_tensors = 1 + padded_dim = head_dim * num_padded_heads * num_tensors + padded_tensor = torch.zeros((new_tensor.shape[0], padded_dim), dtype=new_tensor.dtype, device=new_tensor.device) + padded_tensor[:, :new_tensor.shape[-1]] = new_tensor + new_tensor = padded_tensor + chunk_dim = 1 + elif "attn.proj.bias" in name: + new_name = f"{base}.self_attention.linear_proj.bias" + elif "mlp.fc1.weight" in name: + new_name = f"{base}.mlp.linear_fc1.weight" + chunk_dim = 0 + elif "mlp.fc1.bias" in name: + new_name = f"{base}.mlp.linear_fc1.bias" + chunk_dim = 0 + elif "mlp.fc2.weight" in name: + new_name = f"{base}.mlp.linear_fc2.weight" + chunk_dim = 1 + elif "mlp.fc2.bias" in name: + new_name = f"{base}.mlp.linear_fc2.bias" + elif "norm1" in name: + new_name = f"{base}.input_layernorm.weight" + elif "norm2" in name: + new_name = f"{base}.pre_mlp_layernorm.weight" + else: + raise RuntimeError("unexpected transformer layer name", name) + else: + raise RuntimeError("unexpected layer name", name) + + assert new_name != 
"", f"unexpected layer name {name}" + + # TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility. + extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2") + is_extra_state_layer = any([l in new_name for l in extra_state_layers]) + if use_te and is_extra_state_layer: + layer = new_name.split(".")[-2] + if layer in extra_state_layers: + extra_state_name = ( + new_name[: new_name.rfind(".") + 1] + "_extra_state" + ) # Replace the weight name. + for i in range(tensor_parallel_size): + new_state_dicts[i]["model"][extra_state_name] = None + + if chunk_dim is None: + new_tensors = [new_tensor for _ in range(tensor_parallel_size)] + else: + new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim) + + for i in range(tensor_parallel_size): + new_state_dicts[i]["model"][new_name] = new_tensors[i].clone() + + for i in range(tensor_parallel_size): + output_dir_tp = os.path.join(output_path, f"iter_0000001/mp_rank_0{i}") + os.makedirs(output_dir_tp, exist_ok=True) + output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt") + torch.save(new_state_dicts[i], output_path_tp) + print("saved file", output_path_tp) + + print("done") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="InternVIT HuggingFace to Mcore converter") + parser.add_argument("--model-name", type=str, default="OpenGVLab/InternViT-6B-448px-V1-5", help="Model name in HuggingFace") + parser.add_argument("--output-dir", type=str, required=True, help="Output directory for the mcore model.") + parser.add_argument("--use-te", action="store_true", default=True) + parser.add_argument("--tensor-parallel-size", type=int, required=True) + + args = parser.parse_args() + + convert(args.model_name, args.output_dir, args.tensor_parallel_size, args.use_te) diff --git a/examples/multimodal/model_converter/vision_model_tester.py b/examples/multimodal/model_converter/vision_model_tester.py new file mode 100644 index 0000000000..ef36dd5f9e --- /dev/null +++ b/examples/multimodal/model_converter/vision_model_tester.py @@ -0,0 +1,121 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import argparse +import os +import sys + +# Add megatron and the multimodal example to the path. +sys.path.append( + os.path.abspath( + os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir, os.path.pardir) + ) +) +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) + +import torch +from transformers import AutoModel + +from examples.multimodal.model import model_provider +from examples.multimodal.multimodal_args import add_multimodal_extra_args +from megatron.training import get_model +from megatron.training.checkpointing import load_checkpoint +from megatron.training.initialize import initialize_megatron + + +def run_mcore_vision(model_path): + """Run mcore vision model.""" + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + + # Megatron has some mandatory flags. 
+ sys.argv = [ + "ignore_me.py", + "--micro-batch-size=1", + "--num-layers=2", + "--vision-model-type=internvit", + "--language-model-type=mistral_7b", + "--tokenizer-prompt-format=mistral", + "--tokenizer-type=MultimodalTokenizer", + "--tokenizer-model=mistralai/Mistral-7B-Instruct-v0.3", + "--vocab-size=1024", + "--hidden-size=64", + "--num-attention-heads=8", + "--seq-length=1024", + "--decoder-seq-length=2048", + "--max-position-embeddings=2048", + "--bf16", + "--img-h=448", + "--img-w=448", + "--patch-dim=14", + "--tensor-model-parallel-size=8", + "--use-te", + f"--pretrained-checkpoint={model_path}", + ] + + initialize_megatron(extra_args_provider=add_multimodal_extra_args) + + def wrapped_model_provider(pre_process, post_process): + return model_provider(pre_process, post_process, parallel_output=False) + + # Set up model and load checkpoint. + model = get_model(wrapped_model_provider, wrap_with_ddp=False) + + vision_model = model[0].module.vision_model + + load_checkpoint([vision_model], None, None) + + vision_model.eval() + + images = torch.ones((1, 3, 448, 448), dtype=torch.bfloat16, device="cuda") + + output = vision_model(images) + + return output + + +def run_hf_vision(model_name): + """Run HF vision model.""" + model = ( + AutoModel.from_pretrained(model_name, torch_dtype=torch.bfloat16, trust_remote_code=True) + .cuda() + .eval() + ) + + images = torch.ones((1, 3, 448, 448), dtype=torch.bfloat16, device="cuda") + + outputs = model(images, return_dict=True) + + return outputs + + +def main(mcore_model, hf_model): + """Compare vision model outputs between mcore and HF given the same fixed input.""" + mcore = run_mcore_vision(mcore_model) + + if torch.distributed.get_rank() == 0: + hf = run_hf_vision(hf_model) + hf = hf["last_hidden_state"] + + # Compare logits. Due to different attention implementations and other details, + # there will be numerical differences. + diff = (mcore - hf).abs() + mean_diff = diff.mean().item() + max_diff = diff.max().item() + print(f"mean diff {mean_diff}, max diff {max_diff}") + assert mean_diff < 0.1, "mean output difference is greater than expected" + assert max_diff < 50, "max output difference is greater than expected" + + print("lgtm") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Check mcore vision model output vs. HF numerically.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--mcore-model", type=str, required=True, help="directory for mcore model weights" + ) + parser.add_argument("--hf-model", type=str, required=True, help="Model name in HF") + + args = parser.parse_args() + + main(args.mcore_model, args.hf_model) diff --git a/examples/multimodal/multimodal_args.py b/examples/multimodal/multimodal_args.py index ca38f216bc..1068e92e32 100644 --- a/examples/multimodal/multimodal_args.py +++ b/examples/multimodal/multimodal_args.py @@ -53,5 +53,6 @@ def add_multimodal_extra_args(parser): required=True, help="Prompt format to use with the tokenizer.", ) + group.add_argument("--pixel-shuffle", action="store_true", default=False) return parser diff --git a/examples/multimodal/nvlm/internvit.py b/examples/multimodal/nvlm/internvit.py new file mode 100644 index 0000000000..1f28373ca2 --- /dev/null +++ b/examples/multimodal/nvlm/internvit.py @@ -0,0 +1,256 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +"""" +NOTE: NVLM uses InternViT with tensor parallel (TP) size = 8. 
+Since InternViT has 25 attention heads and Megatron currently requires the number of attention heads +to be divisible by the TP size, we add 7 dummy zero attention heads to have 32 attention heads. + +This workaround requires some changes to how we compute RMSNorm, Attention etc. + +Additionally, InternViT introduces some unique features like Layer Scaling. + +Those code changes are gathered here. +""" +from functools import partial + +import torch + +from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TERowParallelLinear, +) +from megatron.core.parallel_state import ( + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + + +class InternViTRMSNorm(torch.nn.Module): + + def __init__( + self, + config, + hidden_size: int, + eps: float = 1e-6, + sequence_parallel: bool = False, + compute_var: bool = False, + ): + """Custom RMSNorm for InternViT. + + Args: + config (TransformerConfig): Config. + hidden_size (int): Input hidden size. + eps (float): epsilon to use for the norm, default to 1e-6 + sequence_parallel (bool): Set to true if sequence parallelism is being used, + this marks the weights as needing to be allreduced. + compute_var (bool): Indicator to compute statistic manually. + """ + super().__init__() + self.config = config + self.eps = eps + self.weight = torch.nn.Parameter(torch.ones(hidden_size)) + self._compute_var = compute_var + + assert not sequence_parallel, "Sequence parallelism is not supported with InternViT." + + setattr(self.weight, 'sequence_parallel', sequence_parallel) + + def _norm(self, x, var): + if var is None: + var = x.pow(2).mean(-1, keepdim=True) + + return x * torch.rsqrt(var + self.eps) + + def forward(self, x): + """Run RMSNorm with an option to compute custom statistic.""" + var = None + if self._compute_var: + unpadded_hidden_size = self.config.hidden_size # 3200 + max_dim = x.shape[-1] # 128 + + x = x.reshape(x.size(0), x.size(1), -1) + var = self._gather_var(x.float().pow(2), max_dim) / unpadded_hidden_size + + output = self._norm(x.float(), var).type_as(x) + output = output * self.weight + + if self._compute_var: + output = output.reshape(output.size(0), output.size(1), -1, max_dim) + + return output + + def _gather_var(self, input_, max_dim, valid_ranks=6): + """Compute statistic across the non-dummy heads.""" + world_size = get_tensor_model_parallel_world_size() + assert world_size == 8, "tested only with TP=8" + + # Size and dimension. + last_dim = input_.dim() - 1 + rank = get_tensor_model_parallel_rank() + + if rank < valid_ranks: # Ranks 0-5 have 24 non-dummy attention heads. + var = input_.sum(-1, keepdim=True) + elif rank == valid_ranks: # Rank 6 has 1 non-dummy attention head. + var = input_[..., :max_dim].sum(-1, keepdim=True) + else: + var = input_.sum(-1, keepdim=True) * 0.0 # Zero-out the dummy heads. 
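+        # All-gather the per-rank partial sums of squares and add them; forward() divides the
+        # result by the unpadded hidden size (3200) so the dummy heads do not affect the statistic.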
+ + tensor_list = [torch.empty_like(var) for _ in range(world_size)] + tensor_list[rank] = var + torch.distributed.all_gather(tensor_list, var, group=get_tensor_model_parallel_group()) + + output = torch.cat(tensor_list, dim=last_dim).contiguous() + + return output.sum(-1, keepdim=True) + + +def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec: + # Dense MLP w/ or w/o TE modules. + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear, + linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, + ), + ) + + +# Handle InternViT's layer scaling. +def _bias_dropout_add_func_internvit(ls, x_with_bias, residual, prob, training): + x, bias = x_with_bias # unpack + residual = residual if residual.dtype == x.dtype else residual.to(x.dtype) + if bias is not None: + x = x + bias + out = torch.nn.functional.dropout(x, p=prob, training=training) + out = residual + out * ls + return out + else: + out = torch.nn.functional.dropout(x, p=prob, training=training) + out = residual + out * ls + return out + + +def bias_dropout_add_unfused_internvit(ls, training): + """Bias-dropout-add as in Megatron but with added LayerScaling handling.""" + + def _bias_dropout_add(x_with_bias, residual, prob): + return _bias_dropout_add_func_internvit(ls, x_with_bias, residual, prob, training) + + return _bias_dropout_add + + +def get_bias_dropout_add_internvit(ls, training, fused): + """Bias-dropout-add as in Megatron but with added LayerScaling handling.""" + assert not fused, "Fused bias-dropout-add not implemented for InternViT." + return bias_dropout_add_unfused_internvit(ls, training) + + +# Add InternViT specialties to our default TransformerLayer. +class InternViTTransformerLayer(TransformerLayer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.ls1 = torch.nn.Parameter(torch.ones(self.config.hidden_size)) + self.ls2 = torch.nn.Parameter(torch.ones(self.config.hidden_size)) + + self.self_attn_bda = partial(self.self_attn_bda, self.ls1) + self.mlp_bda = partial(self.mlp_bda, self.ls2) + + +# Override a few things that are special in InternViT and not supported by the SelfAttention class. +class InternViTSelfAttention(SelfAttention): + def __init__( + self, config: TransformerConfig, submodules: SelfAttentionSubmodules, *args, **kwargs + ): + super().__init__(config=config, submodules=submodules, *args, **kwargs) + + # Need to override linear_qkv, q_layernorm and k_layernorm. 
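+        # InternViT-6B uses no bias on the QKV projection, and its q/k layernorms run over all
+        # heads of a tensor-parallel partition at once (heads_per_partition * head_dim), not per head.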
+ qkv_bias = False + + self.linear_qkv = build_module( + submodules.linear_qkv, + self.config.hidden_size, + self.query_projection_size + 2 * self.kv_projection_size, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=qkv_bias, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name='qkv', + ) + + qk_layernorm_hidden_size = ( + self.hidden_size_per_attention_head * self.num_attention_heads_per_partition + ) # 512 for internvit + self.q_layernorm = build_module( + submodules.q_layernorm, + hidden_size=qk_layernorm_hidden_size, + config=self.config, + eps=self.config.layernorm_epsilon, + compute_var=True, + ) + + self.k_layernorm = build_module( + submodules.k_layernorm, + hidden_size=qk_layernorm_hidden_size, + config=self.config, + eps=self.config.layernorm_epsilon, + compute_var=True, + ) + + +class InternViTTEDotProductAttention(TEDotProductAttention): + """Adjusted Attention for InternViT""" + + def forward(self, *args, **kwargs): + """Regular TEDotProductAttention + zero-out dummy attention heads.""" + out = super().forward(*args, **kwargs) + + # This makes sure the dummy attention heads are zeroed out. + mask = torch.ones_like(out, dtype=out.dtype, device=out.device) + rank = get_tensor_model_parallel_rank() + max_dim = out.shape[-1] # 128 + valid_ranks = 6 + + if rank == valid_ranks: + mask[..., max_dim:] *= 0.0 + elif rank > valid_ranks: + mask *= 0.0 + out *= mask + + return out + + +def get_internvit_layer_spec(use_te) -> ModuleSpec: + mlp = get_mlp_module_spec(use_te) # no norm + + return ModuleSpec( + module=InternViTTransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=InternViTRMSNorm, + self_attention=ModuleSpec( + module=InternViTSelfAttention, + params={"attn_mask_type": AttnMaskType.no_mask}, + submodules=SelfAttentionSubmodules( + linear_qkv=TEColumnParallelLinear if use_te else ColumnParallelLinear, + core_attention=TEDotProductAttention if use_te else DotProductAttention, + linear_proj=TERowParallelLinear if use_te else RowParallelLinear, + q_layernorm=InternViTRMSNorm, + k_layernorm=InternViTRMSNorm, + ), + ), + self_attn_bda=get_bias_dropout_add_internvit, + pre_mlp_layernorm=InternViTRMSNorm, + mlp=mlp, + mlp_bda=get_bias_dropout_add_internvit, + ), + ) diff --git a/examples/multimodal/run_text_generation.py b/examples/multimodal/run_text_generation.py index 47c7378e0e..3a8d80b42e 100644 --- a/examples/multimodal/run_text_generation.py +++ b/examples/multimodal/run_text_generation.py @@ -143,6 +143,7 @@ def generate_samples(model, config: EvaluationConfig, print_output): args.vision_model_type, args.disable_vision_class_token, 1, + args.pixel_shuffle, ) for idx, (imgs, num_tiles, sample_id, question, answers, metadata) in enumerate(dataloader): diff --git a/megatron/core/models/multimodal/llava_model.py b/megatron/core/models/multimodal/llava_model.py index e6c1e48be0..6a6f7f3325 100644 --- a/megatron/core/models/multimodal/llava_model.py +++ b/megatron/core/models/multimodal/llava_model.py @@ -94,6 +94,7 @@ def __init__( language_rotary_base: int = 10000, language_rope_scaling: bool = False, image_token_index: int = DEFAULT_IMAGE_TOKEN_INDEX, + pixel_shuffle: bool = False, ) -> None: super().__init__(config=language_transformer_config) @@ -198,9 +199,11 @@ def __init__( vision_transformer_config.vision_model_type, drop_vision_class_token, class_token_len, + pixel_shuffle, ) self.image_token_index = image_token_index + self._pixel_shuffle = pixel_shuffle def 
shared_embedding_or_output_weight(self): """This is a convenience method to surface the language model's word embeddings, which is @@ -558,6 +561,12 @@ def forward( image_embeddings = self.vision_model(images) # [num_tiles, img_seq_len, h_vision] if self._drop_vision_class_token: image_embeddings = image_embeddings[:, self.vision_model.class_token_len :, :] + + if self._pixel_shuffle: + image_embeddings = pixel_shuffle( + image_embeddings + ) # [num_tiles, img_seq_len_shuffled, h_vision_shuffled] + # contiguous() required as `permute` can sparsify the tensor and this breaks pipelining image_embeddings = image_embeddings.permute( 1, 0, 2 @@ -676,3 +685,37 @@ def _load_state_dict_hook_ignore_param_names( f"{param_name} being removed from incompatible_keys.missing_keys in LlavaModel" ) incompatible_keys.missing_keys.remove(param_name) + + +# pylint: disable-next=line-too-long +# Based on https://github.com/OpenGVLab/InternVL/blob/c7c5af1a8930b4862afe8ed14672307082ef61fa/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py#L218 +# Copyright (c) 2023 OpenGVLab. +def pixel_shuffle(x, scale_factor=0.5, version=2): + """Pixel shuffle based on InternVL but adapted for our use case. + + Args: + x (torch.Tensor): Vision model outputs [num_tiles, img_seq_len, h_vision] + version (int): Implementation version. + + Returns: + Shuffled vision model outputs [num_tiles, (sq ** 2) * (scale ** 2), h_vision / (scale ** 2)] + """ + h = w = int(x.shape[1] ** 0.5) # sq + x = x.reshape(x.shape[0], h, w, -1) # [num_tiles, sq, sq, h_vision] + + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) + x = x.view( + n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor)) + ) + + if version == 2: + x = x.permute(0, 2, 1, 3).contiguous() + + x = x.reshape(x.shape[0], -1, x.shape[-1]) + + return x diff --git a/megatron/core/models/vision/clip_vit_model.py b/megatron/core/models/vision/clip_vit_model.py index 0661f1ef55..5880b2bb5e 100644 --- a/megatron/core/models/vision/clip_vit_model.py +++ b/megatron/core/models/vision/clip_vit_model.py @@ -51,7 +51,7 @@ def __init__( ) -> None: error_msg = f"CLIPViTModel model subtype {model_subtype} is not supported." - assert model_subtype in ["clip", "siglip"], error_msg + assert model_subtype in ["clip", "siglip", "internvit"], error_msg if model_subtype == "siglip": assert class_token_len == 0, "SigLIP does not support class tokens." 
@@ -90,7 +90,7 @@ def __init__( ) conv_bias = False padding = 0 - if model_subtype == "siglip": + elif model_subtype == "siglip": self.ln_post = build_module( ln_post_impl, config=transformer_config, @@ -99,6 +99,11 @@ def __init__( ) conv_bias = True padding = "valid" + elif model_subtype == "internvit": + conv_bias = True + padding = 0 + else: + raise ValueError(f"unsupported vision model type {model_subtype}") self.conv1 = torch.nn.Conv2d( in_channels=3, @@ -182,17 +187,28 @@ def forward( def get_num_image_embeddings( - img_h, img_w, patch_dim, vision_model_type, disable_vision_class_token, class_token_len + img_h, + img_w, + patch_dim, + vision_model_type, + disable_vision_class_token, + class_token_len, + pixel_shuffle=False, ): """Get the number of image embeddings per image tile.""" if vision_model_type == "siglip": keep_class_token = False - elif vision_model_type == "clip": + elif vision_model_type in ("clip", "internvit"): keep_class_token = not disable_vision_class_token + else: + raise ValueError(f"unsupported vision model: {vision_model_type}") num_patches_per_dim_h = img_h // patch_dim num_patches_per_dim_w = img_w // patch_dim num_patches = num_patches_per_dim_h * num_patches_per_dim_w num_image_embeddings_per_tile = num_patches + (class_token_len if keep_class_token else 0) + if pixel_shuffle: + num_image_embeddings_per_tile = int(num_image_embeddings_per_tile * (0.5**2)) + return num_image_embeddings_per_tile diff --git a/pretrain_vlm.py b/pretrain_vlm.py index 009e86e47f..d9bf308bfe 100644 --- a/pretrain_vlm.py +++ b/pretrain_vlm.py @@ -50,7 +50,8 @@ def model_provider( vision_model_type = "clip" num_image_embeddings = get_num_image_embeddings( - args.img_h, args.img_w, args.patch_dim, vision_model_type, args.disable_vision_class_token, 1 + args.img_h, args.img_w, args.patch_dim, vision_model_type, args.disable_vision_class_token, + class_token_len=1, pixel_shuffle=False, ) old_seq_length = args.seq_length
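
Note (not part of the diff): a minimal standalone sketch of the shape effect of the new pixel shuffle step, mirroring the version=2 path of pixel_shuffle in llava_model.py. It assumes a 448x448 image with patch size 14 (a 32x32 patch grid), InternViT hidden size 3200 and the default scale_factor=0.5; the per-tile token count drops by 4x while the channel dimension grows by 4x, which is why get_num_image_embeddings multiplies the per-tile count by 0.5**2.

import torch

def pixel_shuffle_shape_demo():
    # [num_tiles, img_seq_len, h_vision] as produced by the vision model (class token dropped).
    num_tiles, img_seq_len, h_vision = 1, 1024, 3200
    x = torch.randn(num_tiles, img_seq_len, h_vision)

    scale = 0.5
    h = w = int(x.shape[1] ** 0.5)                      # 32
    x = x.reshape(x.shape[0], h, w, -1)                 # [1, 32, 32, 3200]
    n, w, h, c = x.size()
    x = x.view(n, w, int(h * scale), int(c / scale))    # [1, 32, 16, 6400]
    x = x.permute(0, 2, 1, 3).contiguous()              # [1, 16, 32, 6400]
    x = x.view(n, int(h * scale), int(w * scale), int(c / (scale * scale)))  # [1, 16, 16, 12800]
    x = x.permute(0, 2, 1, 3).contiguous()              # version=2 ordering
    return x.reshape(x.shape[0], -1, x.shape[-1])       # [1, 256, 12800]

print(pixel_shuffle_shape_demo().shape)  # torch.Size([1, 256, 12800])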