From 78d1ab6d97527a73a50e6e352ea934881335da29 Mon Sep 17 00:00:00 2001 From: lkk12014402 Date: Thu, 14 Nov 2024 17:32:22 +0000 Subject: [PATCH 1/5] support llava1.5 lora finetuning. --- examples/image-to-text/README.md | 36 +- .../image-to-text/run_llava_lora_finetune.py | 577 ++++++++++++++++++ .../models/llava/modeling_llava.py | 125 ++-- 3 files changed, 693 insertions(+), 45 deletions(-) create mode 100644 examples/image-to-text/run_llava_lora_finetune.py diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md index 5916de4a29..ba76460b3c 100644 --- a/examples/image-to-text/README.md +++ b/examples/image-to-text/README.md @@ -365,6 +365,40 @@ python3 ../gaudi_spawn.py \ --lora_target_modules '".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"' ``` +Here are single-/multi-device command examples for llava-hf/llava-1.5-7b-hf. + +``` +python3 run_llava_lora_finetune.py \ + --model_name_or_path llava-hf/llava-1.5-7b-hf \ + --dataset_name nielsr/docvqa_1200_examples \ + --bf16 True \ + --output_dir ./model_lora_llava \ + --num_train_epochs 1 \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --weight_decay 0.01 \ + --logging_steps 25 \ + --eval_strategy "no" \ + --save_strategy "no" \ + --learning_rate 5e-5 \ + --warmup_steps 50 \ + --lr_scheduler_type "constant" \ + --input_column_names 'image' 'query' \ + --output_column_names 'answers' \ + --remove_unused_columns False \ + --do_train \ + --do_eval \ + --use_habana \ + --use_lazy_mode \ + --lora_rank=8 \ + --lora_alpha=8 \ + --lora_dropout=0.1 \ + --max_seq_length=512 \ + --use_hpu_graphs_for_inference \ + --low_cpu_mem_usage True +``` + ## Multi-HPU inference To enable multi-card inference, you must set the environment variable `PT_HPU_ENABLE_LAZY_COLLECTIVES=true`, @@ -405,4 +439,4 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json PT_HPU_ENABLE_LAZY_COLLECTI --bf16 \ --use_flash_attention \ --flash_attention_recompute -``` \ No newline at end of file +``` diff --git a/examples/image-to-text/run_llava_lora_finetune.py b/examples/image-to-text/run_llava_lora_finetune.py new file mode 100644 index 0000000000..528b5f61a0 --- /dev/null +++ b/examples/image-to-text/run_llava_lora_finetune.py @@ -0,0 +1,577 @@ +# Apache v2 license +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +lora fine tuning script for image-to-text case +Adapted from the following sources: +https://colab.research.google.com/drive/1rm3AGquGEYXfeeizE40bbDtcWh5S4Nlq?usp=sharing +""" + +import logging +import os +import random +import sys +from dataclasses import dataclass, field +from typing import List, Optional + +import Levenshtein +import torch +import transformers +from datasets import load_dataset +from peft import LoraConfig, get_peft_model +from transformers import ( + AutoConfig, + AutoModelForVision2Seq, + AutoProcessor, + HfArgumentParser, +) +from transformers.trainer_utils import is_main_process + +from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments + + +try: + from optimum.habana.utils import check_optimum_habana_min_version +except ImportError: + + def check_optimum_habana_min_version(*a, **b): + return () + + +os.environ["WANDB_DISABLED"] = "true" + +logger = logging.getLogger(__name__) + +# Will error if the minimal version of Optimum Habana is not installed. Remove at your own risks. +check_optimum_habana_min_version("1.10.0") + + +def normalized_levenshtein(s1, s2): + len_s1, len_s2 = len(s1), len(s2) + distance = Levenshtein.distance(s1, s2) + return distance / max(len_s1, len_s2) + + +def similarity_score(a_ij, o_q_i, tau=0.5): + nl = normalized_levenshtein(a_ij, o_q_i) + return 1 - nl if nl < tau else 0 + + +def average_normalized_levenshtein_similarity(ground_truth, predicted_answers): + assert len(ground_truth) == len(predicted_answers), "Length of ground_truth and predicted_answers must match." + + N = len(ground_truth) + total_score = 0 + + for i in range(N): + a_i = ground_truth[i] + o_q_i = predicted_answers[i] + if o_q_i == "": + print("Warning: Skipped an empty prediction.") + max_score = 0 + else: + max_score = max(similarity_score(a_ij, o_q_i) for a_ij in a_i) + + total_score += max_score + + return total_score / N + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/processor we are going to fine-tune, or train from scratch. + """ + + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "The model checkpoint for weights initialization." + "Don't set if you want to train a model from scratch." + }, + ) + config_name: Optional[str] = field( + default=None, + metadata={"help": "Pretrained config name or path if not the same as model_name"}, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, + ) + token: Optional[str] = field( + default=None, + metadata={"help": "auth token for private models"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + trust_remote_code: bool = field( + default=False, + metadata={ + "help": ( + "Whether to trust the execution of code from datasets/models defined on the Hub." + " This option should only be set to `True` for repositories you trust and in which you have read the" + " code, as it will execute code present on the Hub on your local machine." + ) + }, + ) + use_cache: bool = field( + default=True, + metadata={ + "help": ( + "Whether or not the model should return the last key/values attentions (not used by all models)." 
+ "Only relevant if `config.is_decoder=True`." + ) + }, + ) + low_cpu_mem_usage: bool = field( + default=False, + metadata={ + "help": ( + "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded." + "When set to True, it will benefit LLM loading time and RAM consumption." + ) + }, + ) + load_meta_device: bool = field( + default=False, + metadata={ + "help": ( + "It is an option to load the model to the device instead of the host, so it can reduce the host RAM usage." + "https://huggingface.co/blog/accelerate-large-models" + ) + }, + ) + do_image_splitting: bool = field(default=False, metadata={"help": "Whether to do image split during finetune."}) + + +@dataclass +class DataArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: Optional[str] = field( + default=None, + metadata={"help": "The name of the dataset to use (via the datasets library)."}, + ) + dataset_config_name: Optional[str] = field( + default=None, + metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}, + ) + max_seq_length: Optional[int] = field( + default=512, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated." + }, + ) + overwrite_cache: bool = field( + default=False, + metadata={"help": "Overwrite the cached preprocessed datasets or not."}, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + dataset_seed: int = field( + default=42, + metadata={ + "help": "Seed to use in dataset processing, different seeds might yield different datasets. This seed and the seed in training arguments are not related" + }, + ) + save_last_ckpt: bool = field( + default=True, metadata={"help": "Whether to save checkpoint at the end of the training."} + ) + input_column_names: List[str] = field( + default_factory=lambda: None, + metadata={ + "help": "Name of the column in the dataset that optionally provides context or input for the task. By " + "default, 'image,query' columns are used" + }, + ) + output_column_names: List[str] = field( + default_factory=lambda: None, + metadata={ + "help": "Name of the column in the dataset with the answer to the instruction. By default, the " + "'answers' column is used" + }, + ) + + +@dataclass +class FinetuneArguments: + """ + Arguments of finetune we are going to apply on the model. 
+ """ + + lora_rank: int = field( + default=8, + metadata={"help": "Rank parameter in the LoRA method."}, + ) + lora_alpha: int = field( + default=8, + metadata={"help": "Alpha parameter in the LoRA method."}, + ) + lora_dropout: float = field( + default=0.1, + metadata={"help": "Dropout parameter in the LoRA method."}, + ) + lora_target_modules: str = field( + default=None, + metadata={"help": "Target modules for the LoRA/AdaLoRA method."}, + ) + +class LLavaDataCollator: + def __init__(self, processor, max_seq_length): + self.processor = processor + + num_image_tokens = (self.processor.image_processor.crop_size["height"] // self.processor.patch_size) \ + * (self.processor.image_processor.crop_size["width"] // self.processor.patch_size) + 1 + if self.processor.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + + # text length + image length + self.max_seq_length = max_seq_length + num_image_tokens + + def __call__(self, examples): + texts = [] + images = [] + + keys = list(examples[0].keys()) + if not all(key in ["image", "query", "answers"] for key in keys): + raise ValueError("Unsupported dataset format") + for example in examples: + image = example["image"] + question = example["query"]["en"] + answer = random.choice(example["answers"]) + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Answer briefly."}, + {"type": "image"}, + {"type": "text", "text": question}, + ], + }, + {"role": "assistant", "content": [{"type": "text", "text": answer}]}, + ] + text = self.processor.apply_chat_template(messages, add_generation_prompt=False) + texts.append(text.strip()) + images.append(image) + + batch = self.processor(images, texts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.max_seq_length + ) + + labels = batch["input_ids"].clone() + if self.processor.tokenizer.pad_token_id is not None: + labels[labels == self.processor.tokenizer.pad_token_id] = -100 + batch["labels"] = labels + + return batch + +def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, max_seq_length): + from tqdm import tqdm + + answers_unique = [] + generated_texts_unique = [] + + for i in tqdm(range(0, len(dataset), batch_size)): + examples = dataset[i : i + batch_size] + answers_unique.extend(examples["answers"]) + images = [im for im in examples["image"]] + texts = [] + for q in examples["query"]: + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Answer briefly."}, + {"type": "image"}, + {"type": "text", "text": q["en"]}, + ], + } + ] + text = processor.apply_chat_template(messages, add_generation_prompt=True) + texts.append(text.strip()) + inputs = processor( + images, + texts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_seq_length, + padding_side="left" + ) + inputs = {k: v.to("hpu") for k, v in inputs.items()} + generated_ids = model.generate( + **inputs, max_new_tokens=64, ignore_eos=False, lazy_mode=use_lazy_mode, hpu_graphs=use_hpu_graphs + ) + generated_texts = processor.batch_decode( + generated_ids[:, inputs["input_ids"].size(1) :], skip_special_tokens=True + ) + generated_texts_unique.extend(generated_texts) + generated_texts_unique = [g.strip().strip(".") for g in generated_texts_unique] + anls = average_normalized_levenshtein_similarity( + ground_truth=answers_unique, + predicted_answers=generated_texts_unique, + ) + return anls + +def find_all_linear_names(model): + cls = torch.nn.Linear + lora_module_names = set() + 
multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler'] + for name, module in model.named_modules(): + if any(mm_keyword in name for mm_keyword in multimodal_keywords): + continue + if isinstance(module, cls): + names = name.split('.') + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if 'lm_head' in lora_module_names: # needed for 16-bit + lora_module_names.remove('lm_head') + return list(lora_module_names) + + +def main(): + parser = HfArgumentParser((ModelArguments, DataArguments, GaudiTrainingArguments, FinetuneArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args, finetune_args = parser.parse_json_file( + json_file=os.path.abspath(sys.argv[1]) + ) + else: + ( + model_args, + data_args, + training_args, + finetune_args, + ) = parser.parse_args_into_dataclasses() + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) + + if is_main_process(training_args.local_rank): + transformers.utils.logging.set_verbosity_info() + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + processor = AutoProcessor.from_pretrained( + model_args.model_name_or_path, + do_image_splitting=model_args.do_image_splitting, + padding_side="right", + ) + + config_kwargs = { + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "trust_remote_code": True if model_args.trust_remote_code else None, + "use_cache": False if training_args.gradient_checkpointing else model_args.use_cache, + "token": model_args.token, + } + if model_args.config_name: + config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) + elif model_args.model_name_or_path: + config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) + else: + raise ValueError("Please provide value for model_name_or_path or config_name.") + + setattr(processor, "patch_size", config.vision_config.patch_size) + setattr(processor, "vision_feature_select_strategy", config.vision_feature_select_strategy) + + # Load model + if model_args.model_name_or_path: + model_dtype = torch.bfloat16 if training_args.bf16 else None + model = AutoModelForVision2Seq.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + trust_remote_code=True if model_args.trust_remote_code else None, + torch_dtype=model_dtype, + low_cpu_mem_usage=model_args.low_cpu_mem_usage, + device_map=training_args.device.type if model_args.load_meta_device else None, + token=model_args.token, + ) + else: + raise ValueError("Must provide model_name_or_path to load a pretrained CausalLM model.") + + if finetune_args.lora_target_modules is None: + target_modules = find_all_linear_names(model) + else: + target_modules = finetune_args.lora_target_modules + + lora_config = LoraConfig( + r=finetune_args.lora_rank, + lora_alpha=finetune_args.lora_alpha, + lora_dropout=finetune_args.lora_dropout, + target_modules=target_modules, + init_lora_weights="gaussian", + ) + model = get_peft_model(model, lora_config) + 
model.print_trainable_parameters() + + train_dataset = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + cache_dir=model_args.cache_dir, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + split="train", + ) + + train_dataset = train_dataset.remove_columns( + [ + col + for col in train_dataset.column_names + if col not in (data_args.input_column_names + data_args.output_column_names) + ] + ) + + eval_dataset = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + cache_dir=model_args.cache_dir, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + split="test", + ) + + eval_dataset = eval_dataset.remove_columns( + [ + col + for col in eval_dataset.column_names + if col not in (data_args.input_column_names + data_args.output_column_names) + ] + ) + + data_collator = LLavaDataCollator(processor, max_seq_length=data_args.max_seq_length) + + gaudi_config = GaudiConfig() + gaudi_config.use_fused_adam = True + gaudi_config.use_fused_clip_norm = True + + trainer = GaudiTrainer( + model=model, + args=training_args, + gaudi_config=gaudi_config, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + ) + + if training_args.do_train: + train_result = trainer.train() + trainer.save_model() + metrics = train_result.metrics + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + print("start evaluation....................................................") + + if is_main_process(training_args.local_rank): + setattr(processor, "patch_size", None) + setattr(processor, "vision_feature_select_strategy", None) + processor.tokenizer.padding_side = "left" + + example = eval_dataset[15] + model.eval() + model = model.merge_and_unload() + if model_dtype == torch.bfloat16: + model = model.to(torch.bfloat16) + + image = example["image"] + query = example["query"] + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Answer briefly."}, + {"type": "image"}, + {"type": "text", "text": query["en"]}, + ], + } + ] + text = processor.apply_chat_template(messages, add_generation_prompt=True) + inputs = processor( + [image], + [text.strip()], + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=data_args.max_seq_length, + padding_side="left" + ) + inputs = {k: v.to("hpu") for k, v in inputs.items()} + generated_ids = model.generate( + **inputs, + max_new_tokens=64, + ignore_eos=False, + lazy_mode=training_args.use_lazy_mode, + hpu_graphs=training_args.use_hpu_graphs_for_inference, + ) + generated_texts = processor.batch_decode( + generated_ids[:, inputs["input_ids"].size(1) :], skip_special_tokens=True + ) + logger.info(f"generated: {generated_texts}") + if training_args.do_eval: + if training_args.use_hpu_graphs_for_inference: + from habana_frameworks.torch.hpu import wrap_in_hpu_graph + + model = wrap_in_hpu_graph(model) + + anls = eval( + processor=processor, + model=model, + dataset=eval_dataset, + batch_size=training_args.per_device_eval_batch_size, + use_lazy_mode=training_args.use_lazy_mode, + use_hpu_graphs=training_args.use_hpu_graphs_for_inference, + max_seq_length=data_args.max_seq_length, + ) + eval_metrics = {"eval_accuracy": anls} + trainer.log_metrics("eval", eval_metrics) + trainer.save_metrics("eval", eval_metrics) + + +if __name__ == "__main__": + main() diff --git a/optimum/habana/transformers/models/llava/modeling_llava.py b/optimum/habana/transformers/models/llava/modeling_llava.py index 
997c16d700..c94f5a24da 100644 --- a/optimum/habana/transformers/models/llava/modeling_llava.py +++ b/optimum/habana/transformers/models/llava/modeling_llava.py @@ -22,6 +22,7 @@ from typing import List, Optional, Tuple, Union import torch +import torch.nn as nn from transformers.cache_utils import Cache from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration from transformers.utils import logging @@ -136,49 +137,52 @@ def forward( - add new args tokens_pos """ - if token_idx is not None: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - vision_feature_layer = ( - vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer - ) - vision_feature_select_strategy = ( - vision_feature_select_strategy - if vision_feature_select_strategy is not None - else self.config.vision_feature_select_strategy + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + # 1. Extra the input embeddings + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_features = None + # 2. Merge text and images + if pixel_values is not None and input_ids.shape[1] != 1: + image_outputs = self.vision_tower( + pixel_values, + output_hidden_states=True, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, ) + # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated. + selected_image_feature = image_outputs.hidden_states[vision_feature_layer] - # 1. Extra the input embeddings - inputs_embeds = self.get_input_embeddings()(input_ids) - - image_features = None - # 2. Merge text and images - if pixel_values is not None and input_ids.shape[1] != 1: - image_outputs = self.vision_tower( - pixel_values, - output_hidden_states=True, - use_flash_attention=use_flash_attention, - flash_attention_recompute=flash_attention_recompute, + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError( + f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" ) - # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated. 
- selected_image_feature = image_outputs.hidden_states[vision_feature_layer] - if vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - else: - raise ValueError( - f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" - ) + image_features = self.multi_modal_projector(selected_image_feature) + inputs_embeds = _merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, self.config.image_token_index + ) - image_features = self.multi_modal_projector(selected_image_feature) - inputs_embeds = _merge_input_ids_with_image_features( - image_features, inputs_embeds, input_ids, self.config.image_token_index - ) + + + if token_idx is not None: outputs = self.language_model( attention_mask=attention_mask, @@ -220,20 +224,53 @@ def forward( ) else: - return super().forward( - input_ids=input_ids, - pixel_values=pixel_values, + + outputs = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - vision_feature_layer=vision_feature_layer, - vision_feature_select_strategy=vision_feature_select_strategy, - labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, + # TODO: from Transformers v4.45, `generate` sets `num_logits_to_keep` to 1 if not given, which we don't want here + # num_logits_to_keep=num_logits_to_keep, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + shift_attention_mask = attention_mask[..., 1:] + shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + # print(loss) + + return LlavaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, ) def prepare_inputs_for_generation( From 9e22eb4a50fc84019589b183de2d8af96ff1ca83 Mon Sep 17 00:00:00 2001 From: lkk Date: Mon, 2 Dec 2024 06:21:17 +0000 Subject: [PATCH 2/5] make style --- .../image-to-text/run_llava_lora_finetune.py | 33 ++++++++++--------- .../models/llava/modeling_llava.py | 8 +---- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/examples/image-to-text/run_llava_lora_finetune.py b/examples/image-to-text/run_llava_lora_finetune.py index 528b5f61a0..e38603b086 100644 --- a/examples/image-to-text/run_llava_lora_finetune.py +++ b/examples/image-to-text/run_llava_lora_finetune.py @@ -249,12 +249,14 @@ class FinetuneArguments: metadata={"help": "Target modules for the LoRA/AdaLoRA method."}, ) + class LLavaDataCollator: def __init__(self, 
processor, max_seq_length): self.processor = processor - num_image_tokens = (self.processor.image_processor.crop_size["height"] // self.processor.patch_size) \ - * (self.processor.image_processor.crop_size["width"] // self.processor.patch_size) + 1 + num_image_tokens = (self.processor.image_processor.crop_size["height"] // self.processor.patch_size) * ( + self.processor.image_processor.crop_size["width"] // self.processor.patch_size + ) + 1 if self.processor.vision_feature_select_strategy == "default": num_image_tokens -= 1 @@ -287,11 +289,8 @@ def __call__(self, examples): texts.append(text.strip()) images.append(image) - batch = self.processor(images, texts, - return_tensors="pt", - padding="max_length", - truncation=True, - max_length=self.max_seq_length + batch = self.processor( + images, texts, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_seq_length ) labels = batch["input_ids"].clone() @@ -301,6 +300,7 @@ def __call__(self, examples): return batch + def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, max_seq_length): from tqdm import tqdm @@ -310,7 +310,9 @@ def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, m for i in tqdm(range(0, len(dataset), batch_size)): examples = dataset[i : i + batch_size] answers_unique.extend(examples["answers"]) - images = [im for im in examples["image"]] + images = [] + for im in examples["image"]: + images.append(im) texts = [] for q in examples["query"]: messages = [ @@ -332,7 +334,7 @@ def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, m padding="max_length", truncation=True, max_length=max_seq_length, - padding_side="left" + padding_side="left", ) inputs = {k: v.to("hpu") for k, v in inputs.items()} generated_ids = model.generate( @@ -349,19 +351,20 @@ def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, m ) return anls + def find_all_linear_names(model): cls = torch.nn.Linear lora_module_names = set() - multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler'] + multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"] for name, module in model.named_modules(): if any(mm_keyword in name for mm_keyword in multimodal_keywords): continue if isinstance(module, cls): - names = name.split('.') + names = name.split(".") lora_module_names.add(names[0] if len(names) == 1 else names[-1]) - if 'lm_head' in lora_module_names: # needed for 16-bit - lora_module_names.remove('lm_head') + if "lm_head" in lora_module_names: # needed for 16-bit + lora_module_names.remove("lm_head") return list(lora_module_names) @@ -417,7 +420,7 @@ def main(): setattr(processor, "patch_size", config.vision_config.patch_size) setattr(processor, "vision_feature_select_strategy", config.vision_feature_select_strategy) - # Load model + # Load model if model_args.model_name_or_path: model_dtype = torch.bfloat16 if training_args.bf16 else None model = AutoModelForVision2Seq.from_pretrained( @@ -539,7 +542,7 @@ def main(): padding="max_length", truncation=True, max_length=data_args.max_seq_length, - padding_side="left" + padding_side="left", ) inputs = {k: v.to("hpu") for k, v in inputs.items()} generated_ids = model.generate( diff --git a/optimum/habana/transformers/models/llava/modeling_llava.py b/optimum/habana/transformers/models/llava/modeling_llava.py index c94f5a24da..eb18e85e02 100644 --- a/optimum/habana/transformers/models/llava/modeling_llava.py +++ 
b/optimum/habana/transformers/models/llava/modeling_llava.py @@ -171,19 +171,14 @@ def forward( elif vision_feature_select_strategy == "full": selected_image_feature = selected_image_feature else: - raise ValueError( - f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" - ) + raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") image_features = self.multi_modal_projector(selected_image_feature) inputs_embeds = _merge_input_ids_with_image_features( image_features, inputs_embeds, input_ids, self.config.image_token_index ) - - if token_idx is not None: - outputs = self.language_model( attention_mask=attention_mask, position_ids=position_ids, @@ -224,7 +219,6 @@ def forward( ) else: - outputs = self.language_model( attention_mask=attention_mask, position_ids=position_ids, From e0be37fd3076699acc8b8d3782562b06c897ed65 Mon Sep 17 00:00:00 2001 From: lkk Date: Wed, 4 Dec 2024 07:36:08 +0000 Subject: [PATCH 3/5] for transformers==v4.45.2. --- .../models/llava/modeling_llava.py | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/optimum/habana/transformers/models/llava/modeling_llava.py b/optimum/habana/transformers/models/llava/modeling_llava.py index eb18e85e02..caad37d9b1 100644 --- a/optimum/habana/transformers/models/llava/modeling_llava.py +++ b/optimum/habana/transformers/models/llava/modeling_llava.py @@ -130,7 +130,7 @@ def forward( flash_attention_recompute: Optional[bool] = False, ) -> Union[Tuple, LlavaCausalLMOutputWithPast]: """ - Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llava/modeling_llava.py + Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/llava/modeling_llava.py#L362 The only differences are: - add new args token_idx - add new args image_offset @@ -151,8 +151,27 @@ def forward( else self.config.vision_feature_select_strategy ) - # 1. Extra the input embeddings - inputs_embeds = self.get_input_embeddings()(input_ids) + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + legacy_processing = False + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing + # not very reliable, but we don't expect one to actually pass 500+ images for one prompt + # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True + legacy_processing = ( + (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length + ) or (input_ids.shape[-1] == 1 and pixel_values is not None) + image_features = None # 2. Merge text and images @@ -256,7 +275,6 @@ def forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - # print(loss) return LlavaCausalLMOutputWithPast( loss=loss, From 1ff6602d44c725556f8f7a3cd0549ffb27fa3684 Mon Sep 17 00:00:00 2001 From: lkk Date: Wed, 4 Dec 2024 07:39:58 +0000 Subject: [PATCH 4/5] for transformers==v4.45.2. 
--- optimum/habana/transformers/models/llava/modeling_llava.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/optimum/habana/transformers/models/llava/modeling_llava.py b/optimum/habana/transformers/models/llava/modeling_llava.py index caad37d9b1..cd5fb17d79 100644 --- a/optimum/habana/transformers/models/llava/modeling_llava.py +++ b/optimum/habana/transformers/models/llava/modeling_llava.py @@ -208,8 +208,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - # TODO: from Transformers v4.45, `generate` sets `num_logits_to_keep` to 1 if not given, which we don't want here - # num_logits_to_keep=num_logits_to_keep, + num_logits_to_keep=num_logits_to_keep, token_idx=token_idx + image_offset, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, @@ -248,8 +247,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - # TODO: from Transformers v4.45, `generate` sets `num_logits_to_keep` to 1 if not given, which we don't want here - # num_logits_to_keep=num_logits_to_keep, + num_logits_to_keep=num_logits_to_keep, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, ) From 9a547b5e53c2fc09d1ef3b5ba5f4e0a89192f9e6 Mon Sep 17 00:00:00 2001 From: lkk Date: Fri, 6 Dec 2024 10:10:49 +0000 Subject: [PATCH 5/5] merge two scripts. --- examples/image-to-text/README.md | 4 +- .../run_image2text_lora_finetune.py | 175 +++++- .../image-to-text/run_llava_lora_finetune.py | 580 ------------------ 3 files changed, 148 insertions(+), 611 deletions(-) delete mode 100644 examples/image-to-text/run_llava_lora_finetune.py diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md index 45d24c3be3..4f578b6282 100644 --- a/examples/image-to-text/README.md +++ b/examples/image-to-text/README.md @@ -375,10 +375,10 @@ python3 ../gaudi_spawn.py \ --lora_target_modules '".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"' ``` -Here are single-/multi-device command examples for llava-hf/llava-1.5-7b-hf. +Here are single card training command examples for llava-hf/llava-1.5-7b-hf. 
``` -python3 run_llava_lora_finetune.py \ +python3 run_image2text_lora_finetune.py \ --model_name_or_path llava-hf/llava-1.5-7b-hf \ --dataset_name nielsr/docvqa_1200_examples \ --bf16 True \ diff --git a/examples/image-to-text/run_image2text_lora_finetune.py b/examples/image-to-text/run_image2text_lora_finetune.py index ded60e6d52..b86247376f 100644 --- a/examples/image-to-text/run_image2text_lora_finetune.py +++ b/examples/image-to-text/run_image2text_lora_finetune.py @@ -297,8 +297,58 @@ def __call__(self, examples): return batch +class LLavaDataCollator: + def __init__(self, processor, max_seq_length): + self.processor = processor + + num_image_tokens = (self.processor.image_processor.crop_size["height"] // self.processor.patch_size) * ( + self.processor.image_processor.crop_size["width"] // self.processor.patch_size + ) + 1 + if self.processor.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + + # text length + image length + self.max_seq_length = max_seq_length + num_image_tokens + + def __call__(self, examples): + texts = [] + images = [] + + keys = list(examples[0].keys()) + if not all(key in ["image", "query", "answers"] for key in keys): + raise ValueError("Unsupported dataset format") + for example in examples: + image = example["image"] + question = example["query"]["en"] + answer = random.choice(example["answers"]) + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Answer briefly."}, + {"type": "image"}, + {"type": "text", "text": question}, + ], + }, + {"role": "assistant", "content": [{"type": "text", "text": answer}]}, + ] + text = self.processor.apply_chat_template(messages, add_generation_prompt=False) + texts.append(text.strip()) + images.append(image) + + batch = self.processor( + images, texts, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_seq_length + ) + + labels = batch["input_ids"].clone() + if self.processor.tokenizer.pad_token_id is not None: + labels[labels == self.processor.tokenizer.pad_token_id] = -100 + batch["labels"] = labels + + return batch -def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, max_seq_length): + +def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, max_seq_length, model_arc=""): from tqdm import tqdm answers_unique = [] @@ -307,7 +357,6 @@ def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, m for i in tqdm(range(0, len(dataset), batch_size)): examples = dataset[i : i + batch_size] answers_unique.extend(examples["answers"]) - images = [[im] for im in examples["image"]] texts = [] for q in examples["query"]: messages = [ @@ -322,14 +371,31 @@ def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, m ] text = processor.apply_chat_template(messages, add_generation_prompt=True) texts.append(text.strip()) - inputs = processor( - text=texts, - images=images, - return_tensors="pt", - padding="max_length", - truncation=True, - max_length=max_seq_length, - ) + + if "Llava" in model_arc: + images = [] + for im in examples["image"]: + images.append(im) + + inputs = processor( + images, + texts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_seq_length, + padding_side="left", + ) + else: + images = [[im] for im in examples["image"]] + inputs = processor( + text=texts, + images=images, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_seq_length, + ) inputs = {k: v.to("hpu") for k, v in inputs.items()} 
generated_ids = model.generate( **inputs, max_new_tokens=64, ignore_eos=False, lazy_mode=use_lazy_mode, hpu_graphs=use_hpu_graphs @@ -346,6 +412,22 @@ def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, m return anls +def find_all_linear_names(model): + cls = torch.nn.Linear + lora_module_names = set() + multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"] + for name, module in model.named_modules(): + if any(mm_keyword in name for mm_keyword in multimodal_keywords): + continue + if isinstance(module, cls): + names = name.split(".") + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if "lm_head" in lora_module_names: # needed for 16-bit + lora_module_names.remove("lm_head") + return list(lora_module_names) + + def main(): parser = HfArgumentParser((ModelArguments, DataArguments, GaudiTrainingArguments, FinetuneArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): @@ -380,7 +462,7 @@ def main(): do_image_splitting=model_args.do_image_splitting, padding_side="right", ) - setattr(processor.image_processor, "pad_to_longest_edge", True) + config_kwargs = { "cache_dir": model_args.cache_dir, "revision": model_args.model_revision, @@ -395,7 +477,17 @@ def main(): else: raise ValueError("Please provide value for model_name_or_path or config_name.") - # Load model + model_arc = "" + if config.architectures is not None: + model_arc = config.architectures[0] + + if "Llava" in model_arc: + setattr(processor, "patch_size", config.vision_config.patch_size) + setattr(processor, "vision_feature_select_strategy", config.vision_feature_select_strategy) + else: + setattr(processor.image_processor, "pad_to_longest_edge", True) + + # Load model if model_args.model_name_or_path: model_dtype = torch.bfloat16 if training_args.bf16 else None model = AutoModelForVision2Seq.from_pretrained( @@ -413,11 +505,16 @@ def main(): else: raise ValueError("Must provide model_name_or_path to load a pretrained CausalLM model.") + if finetune_args.lora_target_modules is None: + target_modules = find_all_linear_names(model) + else: + target_modules = finetune_args.lora_target_modules + lora_config = LoraConfig( r=finetune_args.lora_rank, lora_alpha=finetune_args.lora_alpha, lora_dropout=finetune_args.lora_dropout, - target_modules=finetune_args.lora_target_modules, + target_modules=target_modules, init_lora_weights="gaussian", ) model = get_peft_model(model, lora_config) @@ -456,15 +553,19 @@ def main(): if col not in (data_args.input_column_names + data_args.output_column_names) ] ) - if hasattr(config, "image_token_id"): - # idefics - image_token_id = config.image_token_id - elif hasattr(config, "image_token_index"): - # mllama - image_token_id = config.image_token_index + if "Llava" in model_arc: + data_collator = LLavaDataCollator(processor, max_seq_length=data_args.max_seq_length) else: - raise ValueError("Please provide value for image_token_id") - data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length, image_token_id=image_token_id) + if hasattr(config, "image_token_id"): + # idefics + image_token_id = config.image_token_id + elif hasattr(config, "image_token_index"): + # mllama + image_token_id = config.image_token_index + else: + raise ValueError("Please provide value for image_token_id") + + data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length, image_token_id=image_token_id) gaudi_config = GaudiConfig() gaudi_config.use_fused_adam = True @@ -509,14 +610,29 @@ def main(): } ] text 
= processor.apply_chat_template(messages, add_generation_prompt=True) - inputs = processor( - text=[text.strip()], - images=[image], - return_tensors="pt", - padding="max_length", - truncation=True, - max_length=data_args.max_seq_length, - ) + + if "Llava" in model_arc: + # don't expand image_token_id + setattr(processor, "patch_size", None) + setattr(processor, "vision_feature_select_strategy", None) + inputs = processor( + [image], + [text.strip()], + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=data_args.max_seq_length, + padding_side="left", + ) + else: + inputs = processor( + text=[text.strip()], + images=[image], + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=data_args.max_seq_length, + ) inputs = {k: v.to("hpu") for k, v in inputs.items()} generated_ids = model.generate( **inputs, @@ -543,6 +659,7 @@ def main(): use_lazy_mode=training_args.use_lazy_mode, use_hpu_graphs=training_args.use_hpu_graphs_for_inference, max_seq_length=data_args.max_seq_length, + model_arc=model_arc ) eval_metrics = {"eval_accuracy": anls} trainer.log_metrics("eval", eval_metrics) diff --git a/examples/image-to-text/run_llava_lora_finetune.py b/examples/image-to-text/run_llava_lora_finetune.py deleted file mode 100644 index e38603b086..0000000000 --- a/examples/image-to-text/run_llava_lora_finetune.py +++ /dev/null @@ -1,580 +0,0 @@ -# Apache v2 license -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -lora fine tuning script for image-to-text case -Adapted from the following sources: -https://colab.research.google.com/drive/1rm3AGquGEYXfeeizE40bbDtcWh5S4Nlq?usp=sharing -""" - -import logging -import os -import random -import sys -from dataclasses import dataclass, field -from typing import List, Optional - -import Levenshtein -import torch -import transformers -from datasets import load_dataset -from peft import LoraConfig, get_peft_model -from transformers import ( - AutoConfig, - AutoModelForVision2Seq, - AutoProcessor, - HfArgumentParser, -) -from transformers.trainer_utils import is_main_process - -from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments - - -try: - from optimum.habana.utils import check_optimum_habana_min_version -except ImportError: - - def check_optimum_habana_min_version(*a, **b): - return () - - -os.environ["WANDB_DISABLED"] = "true" - -logger = logging.getLogger(__name__) - -# Will error if the minimal version of Optimum Habana is not installed. Remove at your own risks. 
-check_optimum_habana_min_version("1.10.0") - - -def normalized_levenshtein(s1, s2): - len_s1, len_s2 = len(s1), len(s2) - distance = Levenshtein.distance(s1, s2) - return distance / max(len_s1, len_s2) - - -def similarity_score(a_ij, o_q_i, tau=0.5): - nl = normalized_levenshtein(a_ij, o_q_i) - return 1 - nl if nl < tau else 0 - - -def average_normalized_levenshtein_similarity(ground_truth, predicted_answers): - assert len(ground_truth) == len(predicted_answers), "Length of ground_truth and predicted_answers must match." - - N = len(ground_truth) - total_score = 0 - - for i in range(N): - a_i = ground_truth[i] - o_q_i = predicted_answers[i] - if o_q_i == "": - print("Warning: Skipped an empty prediction.") - max_score = 0 - else: - max_score = max(similarity_score(a_ij, o_q_i) for a_ij in a_i) - - total_score += max_score - - return total_score / N - - -@dataclass -class ModelArguments: - """ - Arguments pertaining to which model/config/processor we are going to fine-tune, or train from scratch. - """ - - model_name_or_path: Optional[str] = field( - default=None, - metadata={ - "help": "The model checkpoint for weights initialization." - "Don't set if you want to train a model from scratch." - }, - ) - config_name: Optional[str] = field( - default=None, - metadata={"help": "Pretrained config name or path if not the same as model_name"}, - ) - cache_dir: Optional[str] = field( - default=None, - metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, - ) - token: Optional[str] = field( - default=None, - metadata={"help": "auth token for private models"}, - ) - use_fast_tokenizer: bool = field( - default=True, - metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, - ) - model_revision: str = field( - default="main", - metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, - ) - trust_remote_code: bool = field( - default=False, - metadata={ - "help": ( - "Whether to trust the execution of code from datasets/models defined on the Hub." - " This option should only be set to `True` for repositories you trust and in which you have read the" - " code, as it will execute code present on the Hub on your local machine." - ) - }, - ) - use_cache: bool = field( - default=True, - metadata={ - "help": ( - "Whether or not the model should return the last key/values attentions (not used by all models)." - "Only relevant if `config.is_decoder=True`." - ) - }, - ) - low_cpu_mem_usage: bool = field( - default=False, - metadata={ - "help": ( - "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded." - "When set to True, it will benefit LLM loading time and RAM consumption." - ) - }, - ) - load_meta_device: bool = field( - default=False, - metadata={ - "help": ( - "It is an option to load the model to the device instead of the host, so it can reduce the host RAM usage." - "https://huggingface.co/blog/accelerate-large-models" - ) - }, - ) - do_image_splitting: bool = field(default=False, metadata={"help": "Whether to do image split during finetune."}) - - -@dataclass -class DataArguments: - """ - Arguments pertaining to what data we are going to input our model for training and eval. 
- """ - - dataset_name: Optional[str] = field( - default=None, - metadata={"help": "The name of the dataset to use (via the datasets library)."}, - ) - dataset_config_name: Optional[str] = field( - default=None, - metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}, - ) - max_seq_length: Optional[int] = field( - default=512, - metadata={ - "help": "The maximum total input sequence length after tokenization. Sequences longer " - "than this will be truncated." - }, - ) - overwrite_cache: bool = field( - default=False, - metadata={"help": "Overwrite the cached preprocessed datasets or not."}, - ) - max_train_samples: Optional[int] = field( - default=None, - metadata={ - "help": "For debugging purposes or quicker training, truncate the number of training examples to this " - "value if set." - }, - ) - max_eval_samples: Optional[int] = field( - default=None, - metadata={ - "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " - "value if set." - }, - ) - dataset_seed: int = field( - default=42, - metadata={ - "help": "Seed to use in dataset processing, different seeds might yield different datasets. This seed and the seed in training arguments are not related" - }, - ) - save_last_ckpt: bool = field( - default=True, metadata={"help": "Whether to save checkpoint at the end of the training."} - ) - input_column_names: List[str] = field( - default_factory=lambda: None, - metadata={ - "help": "Name of the column in the dataset that optionally provides context or input for the task. By " - "default, 'image,query' columns are used" - }, - ) - output_column_names: List[str] = field( - default_factory=lambda: None, - metadata={ - "help": "Name of the column in the dataset with the answer to the instruction. By default, the " - "'answers' column is used" - }, - ) - - -@dataclass -class FinetuneArguments: - """ - Arguments of finetune we are going to apply on the model. 
- """ - - lora_rank: int = field( - default=8, - metadata={"help": "Rank parameter in the LoRA method."}, - ) - lora_alpha: int = field( - default=8, - metadata={"help": "Alpha parameter in the LoRA method."}, - ) - lora_dropout: float = field( - default=0.1, - metadata={"help": "Dropout parameter in the LoRA method."}, - ) - lora_target_modules: str = field( - default=None, - metadata={"help": "Target modules for the LoRA/AdaLoRA method."}, - ) - - -class LLavaDataCollator: - def __init__(self, processor, max_seq_length): - self.processor = processor - - num_image_tokens = (self.processor.image_processor.crop_size["height"] // self.processor.patch_size) * ( - self.processor.image_processor.crop_size["width"] // self.processor.patch_size - ) + 1 - if self.processor.vision_feature_select_strategy == "default": - num_image_tokens -= 1 - - # text length + image length - self.max_seq_length = max_seq_length + num_image_tokens - - def __call__(self, examples): - texts = [] - images = [] - - keys = list(examples[0].keys()) - if not all(key in ["image", "query", "answers"] for key in keys): - raise ValueError("Unsupported dataset format") - for example in examples: - image = example["image"] - question = example["query"]["en"] - answer = random.choice(example["answers"]) - messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "Answer briefly."}, - {"type": "image"}, - {"type": "text", "text": question}, - ], - }, - {"role": "assistant", "content": [{"type": "text", "text": answer}]}, - ] - text = self.processor.apply_chat_template(messages, add_generation_prompt=False) - texts.append(text.strip()) - images.append(image) - - batch = self.processor( - images, texts, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_seq_length - ) - - labels = batch["input_ids"].clone() - if self.processor.tokenizer.pad_token_id is not None: - labels[labels == self.processor.tokenizer.pad_token_id] = -100 - batch["labels"] = labels - - return batch - - -def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, max_seq_length): - from tqdm import tqdm - - answers_unique = [] - generated_texts_unique = [] - - for i in tqdm(range(0, len(dataset), batch_size)): - examples = dataset[i : i + batch_size] - answers_unique.extend(examples["answers"]) - images = [] - for im in examples["image"]: - images.append(im) - texts = [] - for q in examples["query"]: - messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "Answer briefly."}, - {"type": "image"}, - {"type": "text", "text": q["en"]}, - ], - } - ] - text = processor.apply_chat_template(messages, add_generation_prompt=True) - texts.append(text.strip()) - inputs = processor( - images, - texts, - return_tensors="pt", - padding="max_length", - truncation=True, - max_length=max_seq_length, - padding_side="left", - ) - inputs = {k: v.to("hpu") for k, v in inputs.items()} - generated_ids = model.generate( - **inputs, max_new_tokens=64, ignore_eos=False, lazy_mode=use_lazy_mode, hpu_graphs=use_hpu_graphs - ) - generated_texts = processor.batch_decode( - generated_ids[:, inputs["input_ids"].size(1) :], skip_special_tokens=True - ) - generated_texts_unique.extend(generated_texts) - generated_texts_unique = [g.strip().strip(".") for g in generated_texts_unique] - anls = average_normalized_levenshtein_similarity( - ground_truth=answers_unique, - predicted_answers=generated_texts_unique, - ) - return anls - - -def find_all_linear_names(model): - cls = torch.nn.Linear - lora_module_names = 
set() - multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"] - for name, module in model.named_modules(): - if any(mm_keyword in name for mm_keyword in multimodal_keywords): - continue - if isinstance(module, cls): - names = name.split(".") - lora_module_names.add(names[0] if len(names) == 1 else names[-1]) - - if "lm_head" in lora_module_names: # needed for 16-bit - lora_module_names.remove("lm_head") - return list(lora_module_names) - - -def main(): - parser = HfArgumentParser((ModelArguments, DataArguments, GaudiTrainingArguments, FinetuneArguments)) - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - # If we pass only one argument to the script and it's the path to a json file, - # let's parse it to get our arguments. - model_args, data_args, training_args, finetune_args = parser.parse_json_file( - json_file=os.path.abspath(sys.argv[1]) - ) - else: - ( - model_args, - data_args, - training_args, - finetune_args, - ) = parser.parse_args_into_dataclasses() - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - handlers=[logging.StreamHandler(sys.stdout)], - ) - logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) - - if is_main_process(training_args.local_rank): - transformers.utils.logging.set_verbosity_info() - transformers.utils.logging.enable_default_handler() - transformers.utils.logging.enable_explicit_format() - - processor = AutoProcessor.from_pretrained( - model_args.model_name_or_path, - do_image_splitting=model_args.do_image_splitting, - padding_side="right", - ) - - config_kwargs = { - "cache_dir": model_args.cache_dir, - "revision": model_args.model_revision, - "trust_remote_code": True if model_args.trust_remote_code else None, - "use_cache": False if training_args.gradient_checkpointing else model_args.use_cache, - "token": model_args.token, - } - if model_args.config_name: - config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) - elif model_args.model_name_or_path: - config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) - else: - raise ValueError("Please provide value for model_name_or_path or config_name.") - - setattr(processor, "patch_size", config.vision_config.patch_size) - setattr(processor, "vision_feature_select_strategy", config.vision_feature_select_strategy) - - # Load model - if model_args.model_name_or_path: - model_dtype = torch.bfloat16 if training_args.bf16 else None - model = AutoModelForVision2Seq.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - revision=model_args.model_revision, - trust_remote_code=True if model_args.trust_remote_code else None, - torch_dtype=model_dtype, - low_cpu_mem_usage=model_args.low_cpu_mem_usage, - device_map=training_args.device.type if model_args.load_meta_device else None, - token=model_args.token, - ) - else: - raise ValueError("Must provide model_name_or_path to load a pretrained CausalLM model.") - - if finetune_args.lora_target_modules is None: - target_modules = find_all_linear_names(model) - else: - target_modules = finetune_args.lora_target_modules - - lora_config = LoraConfig( - r=finetune_args.lora_rank, - lora_alpha=finetune_args.lora_alpha, - lora_dropout=finetune_args.lora_dropout, - target_modules=target_modules, - init_lora_weights="gaussian", - ) - model = get_peft_model(model, lora_config) - 
model.print_trainable_parameters() - - train_dataset = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - cache_dir=model_args.cache_dir, - token=model_args.token, - trust_remote_code=model_args.trust_remote_code, - split="train", - ) - - train_dataset = train_dataset.remove_columns( - [ - col - for col in train_dataset.column_names - if col not in (data_args.input_column_names + data_args.output_column_names) - ] - ) - - eval_dataset = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - cache_dir=model_args.cache_dir, - token=model_args.token, - trust_remote_code=model_args.trust_remote_code, - split="test", - ) - - eval_dataset = eval_dataset.remove_columns( - [ - col - for col in eval_dataset.column_names - if col not in (data_args.input_column_names + data_args.output_column_names) - ] - ) - - data_collator = LLavaDataCollator(processor, max_seq_length=data_args.max_seq_length) - - gaudi_config = GaudiConfig() - gaudi_config.use_fused_adam = True - gaudi_config.use_fused_clip_norm = True - - trainer = GaudiTrainer( - model=model, - args=training_args, - gaudi_config=gaudi_config, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - ) - - if training_args.do_train: - train_result = trainer.train() - trainer.save_model() - metrics = train_result.metrics - trainer.log_metrics("train", metrics) - trainer.save_metrics("train", metrics) - print("start evaluation....................................................") - - if is_main_process(training_args.local_rank): - setattr(processor, "patch_size", None) - setattr(processor, "vision_feature_select_strategy", None) - processor.tokenizer.padding_side = "left" - - example = eval_dataset[15] - model.eval() - model = model.merge_and_unload() - if model_dtype == torch.bfloat16: - model = model.to(torch.bfloat16) - - image = example["image"] - query = example["query"] - - messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "Answer briefly."}, - {"type": "image"}, - {"type": "text", "text": query["en"]}, - ], - } - ] - text = processor.apply_chat_template(messages, add_generation_prompt=True) - inputs = processor( - [image], - [text.strip()], - return_tensors="pt", - padding="max_length", - truncation=True, - max_length=data_args.max_seq_length, - padding_side="left", - ) - inputs = {k: v.to("hpu") for k, v in inputs.items()} - generated_ids = model.generate( - **inputs, - max_new_tokens=64, - ignore_eos=False, - lazy_mode=training_args.use_lazy_mode, - hpu_graphs=training_args.use_hpu_graphs_for_inference, - ) - generated_texts = processor.batch_decode( - generated_ids[:, inputs["input_ids"].size(1) :], skip_special_tokens=True - ) - logger.info(f"generated: {generated_texts}") - if training_args.do_eval: - if training_args.use_hpu_graphs_for_inference: - from habana_frameworks.torch.hpu import wrap_in_hpu_graph - - model = wrap_in_hpu_graph(model) - - anls = eval( - processor=processor, - model=model, - dataset=eval_dataset, - batch_size=training_args.per_device_eval_batch_size, - use_lazy_mode=training_args.use_lazy_mode, - use_hpu_graphs=training_args.use_hpu_graphs_for_inference, - max_seq_length=data_args.max_seq_length, - ) - eval_metrics = {"eval_accuracy": anls} - trainer.log_metrics("eval", eval_metrics) - trainer.save_metrics("eval", eval_metrics) - - -if __name__ == "__main__": - main()
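
The `LLavaDataCollator` carried into `run_image2text_lora_finetune.py` above pads every example to `max_seq_length` plus the number of image placeholder tokens, because once `patch_size` and `vision_feature_select_strategy` are set on the processor (as these patches do), recent Transformers versions expand the single `<image>` token into one token per vision patch. A minimal sketch of that arithmetic, assuming the usual `llava-hf/llava-1.5-7b-hf` vision defaults (336x336 crop, patch size 14, `"default"` feature selection) rather than values read from the checkpoint:

```python
# Assumed defaults for llava-hf/llava-1.5-7b-hf; the script itself reads these
# from the processor/config instead of hard-coding them.
crop_size = {"height": 336, "width": 336}
patch_size = 14
vision_feature_select_strategy = "default"

# Same formula as LLavaDataCollator.__init__: patches per side squared, plus the CLS token.
num_image_tokens = (crop_size["height"] // patch_size) * (crop_size["width"] // patch_size) + 1  # 577
if vision_feature_select_strategy == "default":
    num_image_tokens -= 1  # the CLS feature is dropped, leaving 576 image tokens

max_seq_length = 512  # value passed via --max_seq_length in the README command
padded_length = max_seq_length + num_image_tokens
print(num_image_tokens, padded_length)  # 576 1088
```

So with `--max_seq_length=512` as in the README example, each collated sequence is padded to 1088 tokens, which is worth keeping in mind when choosing `per_device_train_batch_size`.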