From f8ee1217657217a9bea56645a55cb33fad2fa6a3 Mon Sep 17 00:00:00 2001
From: ShinoharaHare
Date: Tue, 24 Sep 2024 19:08:08 +0800
Subject: [PATCH 1/2] feat: Add LLaVa training support
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 scripts/convert_to_hf.py                      | 15 ++-
 src/llm_training/data/__init__.py             |  1 +
 .../visual_instruction_tuning/__init__.py     |  6 ++
 .../visual_instruction_tuning_datacollator.py | 68 +++++++++++++
 .../visual_instruction_tuning_datamodule.py   | 99 +++++++++++++++++++
 ...al_instruction_tuning_datamodule_config.py | 25 +++++
 src/llm_training/lms/clm/clm.py               | 38 ++++---
 src/llm_training/lms/protos/clm_proto.py      | 14 ++-
 src/llm_training/models/__init__.py           |  1 +
 src/llm_training/models/hf_llava/__init__.py  |  2 +
 .../models/hf_llava/hf_llava_config.py        |  4 +
 .../models/hf_llava/hf_llava_model.py         | 91 +++++++++++++++++
 src/llm_training/overrides/cli/utils.py       |  8 +-
 13 files changed, 356 insertions(+), 16 deletions(-)
 create mode 100644 src/llm_training/data/visual_instruction_tuning/__init__.py
 create mode 100644 src/llm_training/data/visual_instruction_tuning/visual_instruction_tuning_datacollator.py
 create mode 100644 src/llm_training/data/visual_instruction_tuning/visual_instruction_tuning_datamodule.py
 create mode 100644 src/llm_training/data/visual_instruction_tuning/visual_instruction_tuning_datamodule_config.py
 create mode 100644 src/llm_training/models/hf_llava/__init__.py
 create mode 100644 src/llm_training/models/hf_llava/hf_llava_config.py
 create mode 100644 src/llm_training/models/hf_llava/hf_llava_model.py

diff --git a/scripts/convert_to_hf.py b/scripts/convert_to_hf.py
index 5dd6f6d..24b06dc 100644
--- a/scripts/convert_to_hf.py
+++ b/scripts/convert_to_hf.py
@@ -72,7 +72,15 @@ def main(
     print('Saving model')
     hf_model.to(dtype).save_pretrained(output_dir)
 
-    if isinstance(datamodule, (PreTrainingDataModule, InstructionTuningDataModule, PreferenceTuningDataModule)):
+    if isinstance(
+        datamodule,
+        (
+            PreTrainingDataModule,
+            InstructionTuningDataModule,
+            PreferenceTuningDataModule,
+            VisualInstructionTuningDataModule
+        )
+    ):
         print('Saving tokenizer')
         tokenizer = datamodule.config.tokenizer
         tokenizer.model_max_length = max(tokenizer.model_max_length, datamodule.config.max_length)
@@ -80,6 +88,11 @@ def main(
         if chat_template is not None:
             tokenizer.chat_template = chat_template
         tokenizer.save_pretrained(output_dir)
+
+        if isinstance(datamodule, VisualInstructionTuningDataModule):
+            print('Saving image processor')
+            image_processor = datamodule.config.image_processor
+            image_processor.save_pretrained(output_dir)
 
 
 def convert_checkpoint(path: Path) -> dict[str, Any]:
diff --git a/src/llm_training/data/__init__.py b/src/llm_training/data/__init__.py
index fd78ef0..f7e9fe1 100644
--- a/src/llm_training/data/__init__.py
+++ b/src/llm_training/data/__init__.py
@@ -5,3 +5,4 @@
 from .instruction_tuning import *
 from .pre_training import *
 from .preference_tuning import *
+from .visual_instruction_tuning import *
diff --git a/src/llm_training/data/visual_instruction_tuning/__init__.py b/src/llm_training/data/visual_instruction_tuning/__init__.py
new file mode 100644
index 0000000..59f0e41
--- /dev/null
+++ b/src/llm_training/data/visual_instruction_tuning/__init__.py
@@ -0,0 +1,6 @@
+from .visual_instruction_tuning_datacollator import \
+    VisualInstructionTuningDataCollator
+from .visual_instruction_tuning_datamodule import \
+    VisualInstructionTuningDataModule
+from .visual_instruction_tuning_datamodule_config import \
+    VisualInstructionTuningDataModuleConfig
diff --git a/src/llm_training/data/visual_instruction_tuning/visual_instruction_tuning_datacollator.py b/src/llm_training/data/visual_instruction_tuning/visual_instruction_tuning_datacollator.py
new file mode 100644
index 0000000..76f879f
--- /dev/null
+++ b/src/llm_training/data/visual_instruction_tuning/visual_instruction_tuning_datacollator.py
@@ -0,0 +1,68 @@
+from typing import Any, TypeVar
+
+import torch
+
+from llm_training.data.base_datacollator import BaseDataCollator
+
+from .visual_instruction_tuning_datamodule_config import \
+    VisualInstructionTuningDataModuleConfig
+
+T = TypeVar('T')
+
+class VisualInstructionTuningDataCollator(BaseDataCollator):
+    config: VisualInstructionTuningDataModuleConfig
+
+    def __init__(self, config: VisualInstructionTuningDataModuleConfig):
+        super().__init__(config)
+
+        assert 'pad_token' in config.tokenizer.special_tokens_map, '`pad_token` is not specified. Please set it manually.'
+
+    def _pad_to_longest(self, batch: list[list[T]], padding_value: T) -> list[list[T]]:
+        n = max(len(y) for y in batch)
+        if self.config.pad_to_multiple_of is not None:
+            n = ((n // self.config.pad_to_multiple_of) + 1) * self.config.pad_to_multiple_of
+
+        new_batch = []
+        for x in batch:
+            num_paddings = n - len(x)
+            paddings = [padding_value] * num_paddings
+            x = paddings + x if self.config.tokenizer.padding_side == 'left' else x + paddings
+            new_batch.append(x)
+
+        return new_batch
+
+    def __call__(self, batch: list[dict[str, Any]]):
+        batch_input_ids = []
+        batch_attention_mask = []
+        batch_position_ids = []
+        batch_labels = []
+        batch_pixel_values = []
+
+        for x in batch:
+            input_ids = x['input_ids']
+            n = len(input_ids)
+
+            batch_input_ids.append(input_ids)
+            batch_attention_mask.append([1] * n)
+            batch_position_ids.append(list(range(n)))
+            batch_labels.append(x['labels'])
+            batch_pixel_values.append(x['pixel_values'])
+
+        batch_input_ids = self._pad_to_longest(batch_input_ids, self.config.tokenizer.pad_token_id)
+        batch_attention_mask = self._pad_to_longest(batch_attention_mask, 0)
+        batch_position_ids = self._pad_to_longest(batch_position_ids, 0)
+        batch_labels = self._pad_to_longest(batch_labels, -100)
+
+        input_ids = torch.tensor(batch_input_ids)
+        attention_mask = torch.tensor(batch_attention_mask)
+        position_ids = torch.tensor(batch_position_ids)
+        labels = torch.tensor(batch_labels)
+        pixel_values = torch.tensor(batch_pixel_values)
+
+        return {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'position_ids': position_ids,
+            'labels': labels,
+            'pixel_values': pixel_values
+        }
diff --git a/src/llm_training/data/visual_instruction_tuning/visual_instruction_tuning_datamodule.py b/src/llm_training/data/visual_instruction_tuning/visual_instruction_tuning_datamodule.py
new file mode 100644
index 0000000..4a50fdf
--- /dev/null
+++ b/src/llm_training/data/visual_instruction_tuning/visual_instruction_tuning_datamodule.py
@@ -0,0 +1,99 @@
+from typing import Any
+
+from transformers import BaseImageProcessor, PreTrainedTokenizerBase
+
+from llm_training.data.hf_based.hf_based_datamodule import (DatasetDict,
+                                                            HFBasedDataModule)
+
+from .visual_instruction_tuning_datacollator import \
+    VisualInstructionTuningDataCollator
+from .visual_instruction_tuning_datamodule_config import \
+    VisualInstructionTuningDataModuleConfig
+
+
+class VisualInstructionTuningDataModule(HFBasedDataModule):
+    config: VisualInstructionTuningDataModuleConfig
+    datacollator_class = VisualInstructionTuningDataCollator
+
+    def __init__(self, config: VisualInstructionTuningDataModuleConfig) -> None:
+        super().__init__(config)
+
+    def pre_process_data(self, dataset_dict: DatasetDict) -> DatasetDict:
+        dataset_dict = self.map_dataset_dict(
+            dataset_dict,
+            _process_text_and_image,
+            remove_columns=True,
+            fn_kwargs=dict(
+                tokenizer=self.config.tokenizer,
+                chat_template=self.config.chat_template,
+                image_processor=self.config.image_processor
+            ),
+            num_proc=self.config.num_proc,
+            desc='Process text and image'
+        )
+
+        dataset_dict = dataset_dict.filter(
+            _drop_overlong,
+            input_columns='input_ids',
+            fn_kwargs=dict(max_length=self.config.max_length),
+            num_proc=self.config.num_proc,
+            desc='Drop overlong'
+        )
+
+        return dataset_dict
+
+
+def _process_text_and_image(
+    example: dict[str, Any],
+    tokenizer: PreTrainedTokenizerBase,
+    chat_template: str | None,
+    image_processor: BaseImageProcessor
+):
+    messages = example['messages']
+    image = example['image']
+
+    input_ids = []
+    labels = []
+
+    system_prompt = None
+    if messages[0]['role'] == 'system':
+        system_prompt = messages.pop(0)
+
+    for i, message in enumerate(messages):
+        conversation = [message]
+        if i == 0 and system_prompt is not None:
+            conversation.insert(0, system_prompt)
+        text = tokenizer.apply_chat_template(
+            conversation,
+            chat_template=chat_template,
+            tokenize=False,
+            add_generation_prompt=message['role'] == 'user',
+            index=i,
+            length=len(messages)
+        )
+        # The same example is tokenized one message at a time here. To keep the result
+        # identical to tokenizing everything at once, a token is prepended first and removed again after encoding.
+        text = tokenizer.bos_token + text
+        current_input_ids = tokenizer.encode(text, add_special_tokens=False)
+        current_input_ids = current_input_ids[1:]
+
+        if message['role'] in ['system', 'user']:
+            input_ids += current_input_ids
+            labels += [-100] * len(current_input_ids)
+        elif message['role'] == 'assistant':
+            input_ids += current_input_ids
+            labels += current_input_ids
+        else:
+            raise ValueError(f"Unknown role: `{message['role']}`")
+
+    pixel_values = image_processor(image).pixel_values[0]
+
+    return {
+        'input_ids': input_ids,
+        'labels': labels,
+        'pixel_values': pixel_values
+    }
+
+
+def _drop_overlong(input_ids: list[int], max_length: int):
+    return len(input_ids) <= max_length
diff --git a/src/llm_training/data/visual_instruction_tuning/visual_instruction_tuning_datamodule_config.py b/src/llm_training/data/visual_instruction_tuning/visual_instruction_tuning_datamodule_config.py
new file mode 100644
index 0000000..99d2ef3
--- /dev/null
+++ b/src/llm_training/data/visual_instruction_tuning/visual_instruction_tuning_datamodule_config.py
@@ -0,0 +1,25 @@
+import logging
+
+from pydantic import field_validator
+from transformers import BaseImageProcessor, PreTrainedTokenizerBase
+
+from llm_training.data.chat_templates import get_chat_template
+from llm_training.data.hf_based.hf_based_datamodule_config import \
+    HFBasedDataModuleConfig
+
+logger = logging.getLogger(__name__)
+
+
+class VisualInstructionTuningDataModuleConfig(HFBasedDataModuleConfig):
+    tokenizer: PreTrainedTokenizerBase
+    image_processor: BaseImageProcessor
+    chat_template: str | None = None
+    max_length: int | None = None
+    pad_to_multiple_of: int | None = None
+
+    @field_validator('chat_template')
+    @classmethod
+    def validate_chat_template(cls, value: str | None) -> str | None:
+        if value is not None:
+            value = get_chat_template(value)
+        return value
diff --git a/src/llm_training/lms/clm/clm.py b/src/llm_training/lms/clm/clm.py
index 2c2db6a..17bd183 100644
--- a/src/llm_training/lms/clm/clm.py
+++ b/src/llm_training/lms/clm/clm.py
@@ -91,17 +91,34 @@ def compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tens
             reduction='mean'
         )
 
-    def training_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int) -> torch.Tensor:
-        labels = shift_labels(batch['labels'], self.config.ignore_index)
+    def forward_batch(self, batch: dict[str, torch.Tensor | Any]) -> tuple[torch.Tensor, torch.Tensor]:
+        kwargs = dict(
+            input_ids=batch['input_ids'],
+            attention_mask=batch['attention_mask']
+        )
+
+        if 'position_ids' in batch:
+            kwargs['position_ids'] = batch['position_ids']
+
+        labels = batch['labels']
+        if 'pixel_values' in batch:
+            kwargs['labels'] = labels
+            kwargs['pixel_values'] = batch['pixel_values']
+            attention_mask, labels, logits = self.model(**kwargs)
+            attention_mask = shift_labels(attention_mask, 0)
+            labels = shift_labels(labels, self.config.ignore_index)
+            labels[attention_mask == 0] = self.config.ignore_index
+        else:
+            logits = self.model(**kwargs)
+            labels = shift_labels(labels, self.config.ignore_index)
+
+        return logits, labels
 
+    def training_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int) -> torch.Tensor:
         if self.config.neftune_alpha is not None:
            self._current_attention_mask = batch['attention_mask']
 
-        logits = self.model(
-            input_ids=batch['input_ids'],
-            attention_mask=batch['attention_mask'],
-            position_ids=batch.get('position_ids', None)
-        )
+        logits, labels = self.forward_batch(batch)
 
         if self.config.neftune_alpha is not None:
             self.log('NEFTune Alpha', self.config.neftune_alpha)
@@ -128,12 +145,7 @@ def training_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int) ->
     def validation_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int, dataloader_idx: int = 0):
         batch_size = batch['input_ids'].size(0)
 
-        labels = shift_labels(batch['labels'], self.config.ignore_index)
-        logits = self.model(
-            input_ids=batch['input_ids'],
-            attention_mask=batch['attention_mask'],
-            position_ids=batch.get('position_ids', None)
-        )
+        logits, labels = self.forward_batch(batch)
 
         self.val_perplexity.update(logits, labels)
         loss = self.compute_loss(logits, labels)
diff --git a/src/llm_training/lms/protos/clm_proto.py b/src/llm_training/lms/protos/clm_proto.py
index 58ac91b..859cbf9 100644
--- a/src/llm_training/lms/protos/clm_proto.py
+++ b/src/llm_training/lms/protos/clm_proto.py
@@ -1,4 +1,4 @@
-from typing import Protocol
+from typing import Protocol, overload
 
 import torch
 from torch import nn
@@ -8,6 +8,7 @@ def get_input_embeddings(self) -> nn.Embedding: ...
 
     def get_output_embeddings(self) -> nn.Linear: ...
 
+    @overload
     def __call__(
         self,
         *,
@@ -17,3 +18,14 @@ def __call__(
         input_embeds: torch.Tensor | None = None
     ) -> torch.Tensor: ...
 
+    @overload
+    def __call__(
+        self,
+        *,
+        input_ids: torch.Tensor | None = None,
+        pixel_values: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.Tensor | None = None,
+        input_embeds: torch.Tensor | None = None,
+        labels: torch.Tensor | None = None
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ...
diff --git a/src/llm_training/models/__init__.py b/src/llm_training/models/__init__.py
index 1d56192..cd0af90 100644
--- a/src/llm_training/models/__init__.py
+++ b/src/llm_training/models/__init__.py
@@ -1,5 +1,6 @@
 from .base_model import *
 from .hf_causal_lm import *
 from .hf_compat_model import *
+from .hf_llava import *
 from .llama import *
 from .phi3 import *
diff --git a/src/llm_training/models/hf_llava/__init__.py b/src/llm_training/models/hf_llava/__init__.py
new file mode 100644
index 0000000..0f5331c
--- /dev/null
+++ b/src/llm_training/models/hf_llava/__init__.py
@@ -0,0 +1,2 @@
+from .hf_llava_config import HFLlavaConfig
+from .hf_llava_model import HFLlava
diff --git a/src/llm_training/models/hf_llava/hf_llava_config.py b/src/llm_training/models/hf_llava/hf_llava_config.py
new file mode 100644
index 0000000..93c9cbb
--- /dev/null
+++ b/src/llm_training/models/hf_llava/hf_llava_config.py
@@ -0,0 +1,4 @@
+from llm_training.models.hf_compat_model import HFCompatModelConfig
+
+class HFLlavaConfig(HFCompatModelConfig):
+    enable_gradient_checkpointing: bool = False
diff --git a/src/llm_training/models/hf_llava/hf_llava_model.py b/src/llm_training/models/hf_llava/hf_llava_model.py
new file mode 100644
index 0000000..154819f
--- /dev/null
+++ b/src/llm_training/models/hf_llava/hf_llava_model.py
@@ -0,0 +1,91 @@
+import torch
+from torch import nn
+from transformers import LlavaConfig, LlavaForConditionalGeneration
+
+from llm_training.models.hf_compat_model import HFCompatModel
+from llm_training.utils.decorators import copy_method_signature
+
+from .hf_llava_config import HFLlavaConfig
+
+
+class HFLlava(HFCompatModel):
+    config: HFLlavaConfig
+    hf_config: LlavaConfig
+    hf_model: LlavaForConditionalGeneration
+
+    config_class = HFLlavaConfig
+    hf_config_class = LlavaConfig
+    hf_model_class = LlavaForConditionalGeneration
+
+    @property
+    def no_split_modules(self) -> list[str] | None:
+        return self.hf_model._no_split_modules
+
+    def __init__(self, config: HFLlavaConfig) -> None:
+        super().__init__(config)
+
+        self.hf_model = self.construct_hf_model()
+
+        if self.config.enable_gradient_checkpointing:
+            self.hf_model.gradient_checkpointing_enable({'use_reentrant': False})
+
+    def convert_state_dict_from_hf(self, hf_state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+        return {'hf_model.' + k: v for k, v in hf_state_dict.items()}
+
+    def convert_state_dict_to_hf(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+        return {k.removeprefix('hf_model.'): v for k, v in state_dict.items()}
+
+    def forward(
+        self,
+        input_ids: torch.Tensor | None = None,
+        pixel_values: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.Tensor | None = None,
+        input_embeds: torch.Tensor | None = None,
+        labels: torch.Tensor | None = None
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        vision_feature_layer = self.hf_config.vision_feature_layer
+        vision_feature_select_strategy = self.hf_config.vision_feature_select_strategy
+
+        if input_embeds is None:
+            # 1. Extract the input embeddings
+            input_embeds = self.get_input_embeddings()(input_ids)
+
+            # 2. Merge text and images
+            if pixel_values is not None and input_ids.shape[1] != 1:
+                image_outputs = self.hf_model.vision_tower(pixel_values, output_hidden_states=True)
+                # this is not memory efficient at all; (output_hidden_states=True) will save all the hidden states.
+                selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+
+                if vision_feature_select_strategy == "default":
+                    selected_image_feature = selected_image_feature[:, 1:]
+                elif vision_feature_select_strategy == "full":
+                    selected_image_feature = selected_image_feature
+                else:
+                    raise ValueError(
+                        f"Unexpected select feature strategy: {self.hf_config.vision_feature_select_strategy}"
+                    )
+
+                image_features = self.hf_model.multi_modal_projector(selected_image_feature)
+                input_embeds = input_embeds.to(image_features.dtype)
+                input_embeds, attention_mask, labels, position_ids = self.hf_model._merge_input_ids_with_image_features(
+                    image_features, input_embeds, input_ids, attention_mask, labels
+                )
+
+        outputs = self.hf_model.language_model(
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=input_embeds
+        )
+
+        logits = outputs[0]
+        return attention_mask, labels, logits
+
+    @copy_method_signature(forward)
+    def __call__(): ...
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.hf_model.get_input_embeddings()
+
+    def get_output_embeddings(self) -> nn.Linear:
+        return self.hf_model.get_output_embeddings()
diff --git a/src/llm_training/overrides/cli/utils.py b/src/llm_training/overrides/cli/utils.py
index 01fd776..8887726 100644
--- a/src/llm_training/overrides/cli/utils.py
+++ b/src/llm_training/overrides/cli/utils.py
@@ -1,5 +1,6 @@
 from jsonargparse import class_from_function
-from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from transformers import (AutoImageProcessor, AutoTokenizer,
+                          BaseImageProcessor, PreTrainedTokenizerBase)
 
 
 def _load_tokenizer(path: str, pad_token: str | None = None, **kwargs) -> PreTrainedTokenizerBase:
@@ -8,4 +9,9 @@ def _load_tokenizer(path: str, pad_token: str | None = None, **kwargs) -> PreTra
     return AutoTokenizer.from_pretrained(path, **kwargs)
 
 
+def _load_image_processor(path: str, pad_token: str | None = None, **kwargs) -> BaseImageProcessor:
+    return AutoImageProcessor.from_pretrained(path, **kwargs)
+
+
 HFTokenizer = class_from_function(_load_tokenizer, name='HFTokenizer')
+HFImageProcessor = class_from_function(_load_image_processor, name='HFImageProcessor')

From 0c435d88bdb1ba13493f52e60eeb7ad3729c6a2b Mon Sep 17 00:00:00 2001
From: ShinoharaHare
Date: Tue, 24 Sep 2024 19:09:12 +0800
Subject: [PATCH 2/2] docs: Add LLaVa examples
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../llava/llava-phi-3.5_ft_example.yaml | 74 +++++++++++++++++++
 .../llava/llava-phi-3.5_pt_example.yaml | 74 +++++++++++++++++++
 2 files changed, 148 insertions(+)
 create mode 100644 config/examples/llava/llava-phi-3.5_ft_example.yaml
 create mode 100644 config/examples/llava/llava-phi-3.5_pt_example.yaml

diff --git a/config/examples/llava/llava-phi-3.5_ft_example.yaml b/config/examples/llava/llava-phi-3.5_ft_example.yaml
new file mode 100644
index 0000000..d21b612
--- /dev/null
+++ b/config/examples/llava/llava-phi-3.5_ft_example.yaml
@@ -0,0 +1,74 @@
+seed_everything: 42
+float32_matmul_precision: medium
+logging_level: DEBUG
+
+trainer:
+  strategy:
+    class_path: llm_training.overrides.DeepSpeedStrategy
+    init_args:
+      stage: 2
+  precision: bf16-true
+  logger:
+    class_path: llm_training.overrides.wandb.WandbLogger
+    init_args:
+      name: llava-phi-3.5_ft_example
+      job_type: example
+      project: llm-training
+      save_dir: logs
+      save_code: true
+  max_epochs: 1
+  val_check_interval: null
+  accumulate_grad_batches: 1
+  gradient_clip_val: 1.0
+  callbacks:
+    - class_path: LearningRateMonitor
+    - class_path: llm_training.overrides.ModelCheckpoint
+      init_args:
+        save_on_train_epoch_end: true
+        save_top_k: 1
+
+model:
+  class_path: llm_training.lms.CLM
+  init_args.config:
+    model:
+      model_class: llm_training.models.HFLlava
+      model_config:
+        hf_path: 
+        torch_dtype: bfloat16
+        attn_implementation: flash_attention_2
+        enable_gradient_checkpointing: true
+
+      frozen_modules:
+        - vision_tower$
+
+    # neftune_alpha: 5.0
+
+    optim:
+      optimizer_class: deepspeed.ops.adam.FusedAdam
+      optimizer_kwargs:
+        lr: 2e-5
+      lr_scheduler_class: llm_training.lr_schedulers.CosineAnnealingWarmupLR
+      lr_scheduler_kwargs:
+        num_warmup_steps: 100
+        min_lr: 2e-6
+
+data:
+  class_path: llm_training.data.VisualInstructionTuningDataModule
+  init_args.config:
+    dataset_kwargs:
+      path: ShinoharaHare/LLaVA-NeXT-Data-Reformatted
+      name: 10K
+    tokenizer:
+      class_path: HFTokenizer
+      init_args.path: ShinoharaHare/llava-phi-3.5-mini-instruct_untrained
+    chat_template: phi-3
+    image_processor:
+      class_path: HFImageProcessor
+      init_args.path: ShinoharaHare/llava-phi-3.5-mini-instruct_untrained
+    batch_size: 1
+    max_length: 4096
+    pad_to_multiple_of: 64
+    validation_split: null
+    num_proc: 4
+    num_workers: 4
+    enable_cache: true
diff --git a/config/examples/llava/llava-phi-3.5_pt_example.yaml b/config/examples/llava/llava-phi-3.5_pt_example.yaml
new file mode 100644
index 0000000..add985a
--- /dev/null
+++ b/config/examples/llava/llava-phi-3.5_pt_example.yaml
@@ -0,0 +1,74 @@
+seed_everything: 42
+float32_matmul_precision: medium
+logging_level: DEBUG
+
+trainer:
+  strategy:
+    class_path: llm_training.overrides.DeepSpeedStrategy
+    init_args:
+      stage: 2
+      exclude_frozen_parameters: true
+  precision: bf16-true
+  logger:
+    class_path: llm_training.overrides.wandb.WandbLogger
+    init_args:
+      name: llava-phi-3.5_pt_example
+      job_type: example
+      project: llm-training
+      save_dir: logs
+      save_code: true
+  max_epochs: 1
+  val_check_interval: null
+  accumulate_grad_batches: 1
+  gradient_clip_val: 1.0
+  callbacks:
+    - class_path: LearningRateMonitor
+    - class_path: llm_training.overrides.ModelCheckpoint
+      init_args:
+        save_on_train_epoch_end: true
+        save_top_k: 1
+
+model:
+  class_path: llm_training.lms.CLM
+  init_args.config:
+    model:
+      model_class: llm_training.models.HFLlava
+      model_config:
+        hf_path: ShinoharaHare/llava-phi-3.5-mini-instruct_untrained
+        torch_dtype: bfloat16
+        attn_implementation: flash_attention_2
+        enable_gradient_checkpointing: true
+
+      frozen_modules:
+        - vision_tower$
+        - language_model$
+
+    # neftune_alpha: 5.0
+
+    optim:
+      optimizer_class: deepspeed.ops.adam.FusedAdam
+      optimizer_kwargs:
+        lr: 1e-3
+      lr_scheduler_kwargs:
+        num_warmup_steps: 100
+
+data:
+  class_path: llm_training.data.VisualInstructionTuningDataModule
+  init_args.config:
+    dataset_kwargs:
+      path: ShinoharaHare/LLaVA-NeXT-Data-Reformatted
+      name: 10K
+    tokenizer:
+      class_path: HFTokenizer
+      init_args.path: ShinoharaHare/llava-phi-3.5-mini-instruct_untrained
+    chat_template: phi-3
+    image_processor:
+      class_path: HFImageProcessor
+      init_args.path: ShinoharaHare/llava-phi-3.5-mini-instruct_untrained
+    batch_size: 1
+    max_length: 4096
+    pad_to_multiple_of: 64
+    validation_split: null
+    num_proc: 4
+    num_workers: 4
+    enable_cache: true
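
Note on `_process_text_and_image` in patch 1: each message is tokenized separately, and the BOS token is prepended before encoding and its id dropped afterwards, so that the per-message encodings concatenate to the same ids as encoding the whole rendered conversation in one pass. A minimal standalone sketch of that trick follows; it assumes only a Hugging Face tokenizer with a BOS token (the checkpoint name is the one referenced in the example configs), and exact equality of the two encodings still depends on the tokenizer and on where the message boundaries fall.

from transformers import AutoTokenizer

# Tokenizer taken from the example configs above; any tokenizer that defines a BOS token behaves the same way.
tokenizer = AutoTokenizer.from_pretrained('ShinoharaHare/llava-phi-3.5-mini-instruct_untrained')

def encode_as_continuation(text: str) -> list[int]:
    # Prepending the BOS token makes the tokenizer treat `text` as a continuation
    # rather than the start of a string; the BOS id is then stripped off again.
    ids = tokenizer.encode(tokenizer.bos_token + text, add_special_tokens=False)
    return ids[1:]

segments = [
    '<|user|>\nWhat is in the image?<|end|>\n',
    '<|assistant|>\nA cat sitting on a mat.<|end|>\n',
]

per_message = [token_id for segment in segments for token_id in encode_as_continuation(segment)]
all_at_once = tokenizer.encode(''.join(segments), add_special_tokens=False)

print(per_message)
print(all_at_once)  # should match the per-message encoding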
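
The multimodal branch of the new `CLM.forward_batch` (patch 1) expects the model to return the expanded `attention_mask` and `labels` after image features have been merged; both are then shifted by one position and padded slots are re-masked with the ignore index before the loss is computed. A small self-contained illustration of that label handling in plain PyTorch follows; the `shift_labels` helper here is a stand-in written for this sketch and is only assumed to behave like the repo's helper (shift left by one, pad the tail with the ignore index).

import torch

IGNORE_INDEX = -100

def shift_labels(labels: torch.Tensor, ignore_index: int) -> torch.Tensor:
    # Assumed behaviour of the repo's helper: drop the first position and append an
    # ignored slot so that labels[t] lines up with the logit that predicts token t+1.
    pad = torch.full_like(labels[:, :1], ignore_index)
    return torch.cat([labels[:, 1:], pad], dim=1)

# One sequence of length 4 whose last position is padding.
attention_mask = torch.tensor([[1, 1, 1, 0]])
labels = torch.tensor([[11, 12, 13, IGNORE_INDEX]])

shifted_mask = shift_labels(attention_mask, 0)
shifted_labels = shift_labels(labels, IGNORE_INDEX)
shifted_labels[shifted_mask == 0] = IGNORE_INDEX  # re-mask padded positions

print(shifted_labels)  # tensor([[  12,   13, -100, -100]])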