Add LLaVa support #11

Open

wants to merge 2 commits into main
74 changes: 74 additions & 0 deletions config/examples/llava/llava-phi-3.5_ft_example.yaml
@@ -0,0 +1,74 @@
seed_everything: 42
float32_matmul_precision: medium
logging_level: DEBUG

trainer:
strategy:
class_path: llm_training.overrides.DeepSpeedStrategy
init_args:
stage: 2
precision: bf16-true
logger:
class_path: llm_training.overrides.wandb.WandbLogger
init_args:
name: llava-phi-3.5_ft_example
job_type: example
project: llm-training
save_dir: logs
save_code: true
max_epochs: 1
val_check_interval: null
accumulate_grad_batches: 1
gradient_clip_val: 1.0
callbacks:
- class_path: LearningRateMonitor
- class_path: llm_training.overrides.ModelCheckpoint
init_args:
save_on_train_epoch_end: true
save_top_k: 1

model:
class_path: llm_training.lms.CLM
init_args.config:
model:
model_class: llm_training.models.HFLlava
model_config:
hf_path: <PT_MODEL_PATH>
torch_dtype: bfloat16
attn_implementation: flash_attention_2
enable_gradient_checkpointing: true

frozen_modules:
- vision_tower$

# neftune_alpha: 5.0

optim:
optimizer_class: deepspeed.ops.adam.FusedAdam
optimizer_kwargs:
lr: 2e-5
lr_scheduler_class: llm_training.lr_schedulers.CosineAnnealingWarmupLR
lr_scheduler_kwargs:
num_warmup_steps: 100
min_lr: 2e-6

data:
class_path: llm_training.data.VisualInstructionTuningDataModule
init_args.config:
dataset_kwargs:
path: ShinoharaHare/LLaVA-NeXT-Data-Reformatted
name: 10K
tokenizer:
class_path: HFTokenizer
init_args.path: ShinoharaHare/llava-phi-3.5-mini-instruct_untrained
chat_template: phi-3
image_processor:
class_path: HFImageProcessor
init_args.path: ShinoharaHare/llava-phi-3.5-mini-instruct_untrained
batch_size: 1
max_length: 4096
pad_to_multiple_of: 64
validation_split: null
num_proc: 4
num_workers: 4
enable_cache: true
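A note on this config: the frozen_modules entries look like regular-expression patterns (hence the trailing $), so vision_tower$ keeps the vision encoder frozen while the projector and language model are fine-tuned, and <PT_MODEL_PATH> presumably points at the HF-converted output of the pretraining config below. The following is only a minimal sketch of how such suffix patterns could be applied to a model's submodules; the real freezing logic lives inside llm_training and may differ.

import re

import torch.nn as nn


def freeze_modules(model: nn.Module, patterns: list[str]) -> None:
    # Illustrative helper (not the project's implementation): freeze every
    # submodule whose qualified name matches one of the regex patterns.
    for name, module in model.named_modules():
        if any(re.search(pattern, name) for pattern in patterns):
            module.requires_grad_(False)


# e.g. freeze_modules(model, ['vision_tower$']) for fine-tuning, or
#      freeze_modules(model, ['vision_tower$', 'language_model$']) for pretraining.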
74 changes: 74 additions & 0 deletions config/examples/llava/llava-phi-3.5_pt_example.yaml
@@ -0,0 +1,74 @@
seed_everything: 42
float32_matmul_precision: medium
logging_level: DEBUG

trainer:
strategy:
class_path: llm_training.overrides.DeepSpeedStrategy
init_args:
stage: 2
exclude_frozen_parameters: true
precision: bf16-true
logger:
class_path: llm_training.overrides.wandb.WandbLogger
init_args:
name: llava-phi-3.5_pt_example
job_type: example
project: llm-training
save_dir: logs
save_code: true
max_epochs: 1
val_check_interval: null
accumulate_grad_batches: 1
gradient_clip_val: 1.0
callbacks:
- class_path: LearningRateMonitor
- class_path: llm_training.overrides.ModelCheckpoint
init_args:
save_on_train_epoch_end: true
save_top_k: 1

model:
class_path: llm_training.lms.CLM
init_args.config:
model:
model_class: llm_training.models.HFLlava
model_config:
hf_path: ShinoharaHare/llava-phi-3.5-mini-instruct_untrained
torch_dtype: bfloat16
attn_implementation: flash_attention_2
enable_gradient_checkpointing: true

frozen_modules:
- vision_tower$
- language_model$

# neftune_alpha: 5.0

optim:
optimizer_class: deepspeed.ops.adam.FusedAdam
optimizer_kwargs:
lr: 1e-3
lr_scheduler_kwargs:
num_warmup_steps: 100

data:
class_path: llm_training.data.VisualInstructionTuningDataModule
init_args.config:
dataset_kwargs:
path: ShinoharaHare/LLaVA-NeXT-Data-Reformatted
name: 10K
tokenizer:
class_path: HFTokenizer
init_args.path: ShinoharaHare/llava-phi-3.5-mini-instruct_untrained
chat_template: phi-3
image_processor:
class_path: HFImageProcessor
init_args.path: ShinoharaHare/llava-phi-3.5-mini-instruct_untrained
batch_size: 1
max_length: 4096
pad_to_multiple_of: 64
validation_split: null
num_proc: 4
num_workers: 4
enable_cache: true
15 changes: 14 additions & 1 deletion scripts/convert_to_hf.py
@@ -72,14 +72,27 @@ def main(
print('Saving model')
hf_model.to(dtype).save_pretrained(output_dir)

if isinstance(datamodule, (PreTrainingDataModule, InstructionTuningDataModule, PreferenceTuningDataModule)):
if isinstance(
datamodule,
(
PreTrainingDataModule,
InstructionTuningDataModule,
PreferenceTuningDataModule,
VisualInstructionTuningDataModule
)
):
print('Saving tokenizer')
tokenizer = datamodule.config.tokenizer
tokenizer.model_max_length = max(tokenizer.model_max_length, datamodule.config.max_length)
chat_template = getattr(datamodule.config, 'chat_template', None)
if chat_template is not None:
tokenizer.chat_template = chat_template
tokenizer.save_pretrained(output_dir)

if isinstance(datamodule, VisualInstructionTuningDataModule):
print('Saving image processor')
image_processor = datamodule.config.image_processor
image_processor.save_pretrained(output_dir)


def convert_checkpoint(path: Path) -> dict[str, Any]:
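With this change, converting a run that used the VisualInstructionTuningDataModule also writes the image processor next to the model and tokenizer, so the exported directory can be loaded with the standard transformers auto classes. A small usage sketch, assuming the converted model is a regular LLaVA checkpoint; the output path here is hypothetical:

from transformers import (AutoImageProcessor, AutoTokenizer,
                          LlavaForConditionalGeneration)

output_dir = 'outputs/llava-phi-3.5_ft_example-hf'  # hypothetical path passed to convert_to_hf.py
model = LlavaForConditionalGeneration.from_pretrained(output_dir)
tokenizer = AutoTokenizer.from_pretrained(output_dir)
image_processor = AutoImageProcessor.from_pretrained(output_dir)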
1 change: 1 addition & 0 deletions src/llm_training/data/__init__.py
@@ -5,3 +5,4 @@
from .instruction_tuning import *
from .pre_training import *
from .preference_tuning import *
from .visual_instruction_tuning import *
6 changes: 6 additions & 0 deletions src/llm_training/data/visual_instruction_tuning/__init__.py
@@ -0,0 +1,6 @@
from .visual_instruction_tuning_datacollator import \
VisualInstructionTuningDataCollator
from .visual_instruction_tuning_datamodule import \
VisualInstructionTuningDataModule
from .visual_instruction_tuning_datamodule_config import \
VisualInstructionTuningDataModuleConfig
68 changes: 68 additions & 0 deletions src/llm_training/data/visual_instruction_tuning/visual_instruction_tuning_datacollator.py
@@ -0,0 +1,68 @@
from typing import Any, TypeVar

import torch

from llm_training.data.base_datacollator import BaseDataCollator

from .visual_instruction_tuning_datamodule_config import \
VisualInstructionTuningDataModuleConfig

T = TypeVar('T')

class VisualInstructionTuningDataCollator(BaseDataCollator):
config: VisualInstructionTuningDataModuleConfig

def __init__(self, config: VisualInstructionTuningDataModuleConfig):
super().__init__(config)

assert 'pad_token' in config.tokenizer.special_tokens_map, '`pad_token` is not specified. Please set it manually.'

def _pad_to_longest(self, batch: list[list[T]], padding_value: T) -> list[list[T]]:
n = max(len(y) for y in batch)
if self.config.pad_to_multiple_of is not None:
n = ((n + self.config.pad_to_multiple_of - 1) // self.config.pad_to_multiple_of) * self.config.pad_to_multiple_of  # round up to the nearest multiple

new_batch = []
for x in batch:
num_paddings = n - len(x)
paddings = [padding_value] * num_paddings
x = paddings + x if self.config.tokenizer.padding_side == 'left' else x + paddings
new_batch.append(x)

return new_batch

def __call__(self, batch: list[dict[str, Any]]):
batch_input_ids = []
batch_attention_mask = []
batch_position_ids = []
batch_labels = []
batch_pixel_values = []

for x in batch:
input_ids = x['input_ids']
n = len(input_ids)

batch_input_ids.append(input_ids)
batch_attention_mask.append([1] * n)
batch_position_ids.append(list(range(n)))
batch_labels.append(x['labels'])
batch_pixel_values.append(x['pixel_values'])

batch_input_ids = self._pad_to_longest(batch_input_ids, self.config.tokenizer.pad_token_id)
batch_attention_mask = self._pad_to_longest(batch_attention_mask, 0)
batch_position_ids = self._pad_to_longest(batch_position_ids, 0)
batch_labels = self._pad_to_longest(batch_labels, -100)

input_ids = torch.tensor(batch_input_ids)
attention_mask = torch.tensor(batch_attention_mask)
position_ids = torch.tensor(batch_position_ids)
labels = torch.tensor(batch_labels)
pixel_values = torch.tensor(batch_pixel_values)

return {
'input_ids': input_ids,
'attention_mask': attention_mask,
'position_ids': position_ids,
'labels': labels,
'pixel_values': pixel_values
}
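For reference, the collator pads every per-example list (input ids, attention mask, position ids, labels) to a shared length rounded up to pad_to_multiple_of, padding on the side given by tokenizer.padding_side, and then stacks everything into tensors. A simplified, standalone version of that padding step, independent of the project classes:

import torch


def pad_to_longest(batch: list[list[int]], padding_value: int,
                   multiple_of: int | None = 64, padding_side: str = 'right') -> torch.Tensor:
    # Simplified illustration of the collator's padding step.
    n = max(len(x) for x in batch)
    if multiple_of is not None:
        n = (n + multiple_of - 1) // multiple_of * multiple_of  # round up to the nearest multiple
    rows = []
    for x in batch:
        pads = [padding_value] * (n - len(x))
        rows.append(pads + x if padding_side == 'left' else x + pads)
    return torch.tensor(rows)


# Two sequences of lengths 3 and 5 both become rows of length 64:
print(pad_to_longest([[1, 2, 3], [4, 5, 6, 7, 8]], padding_value=0).shape)  # torch.Size([2, 64])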
99 changes: 99 additions & 0 deletions src/llm_training/data/visual_instruction_tuning/visual_instruction_tuning_datamodule.py
@@ -0,0 +1,99 @@
from typing import Any

from transformers import BaseImageProcessor, PreTrainedTokenizerBase

from llm_training.data.hf_based.hf_based_datamodule import (DatasetDict,
HFBasedDataModule)

from .visual_instruction_tuning_datacollator import \
VisualInstructionTuningDataCollator
from .visual_instruction_tuning_datamodule_config import \
VisualInstructionTuningDataModuleConfig


class VisualInstructionTuningDataModule(HFBasedDataModule):
config: VisualInstructionTuningDataModuleConfig
datacollator_class = VisualInstructionTuningDataCollator

def __init__(self, config: VisualInstructionTuningDataModuleConfig) -> None:
super().__init__(config)

def pre_process_data(self, dataset_dict: DatasetDict) -> DatasetDict:
dataset_dict = self.map_dataset_dict(
dataset_dict,
_process_text_and_image,
remove_columns=True,
fn_kwargs=dict(
tokenizer=self.config.tokenizer,
chat_template=self.config.chat_template,
image_processor=self.config.image_processor
),
num_proc=self.config.num_proc,
desc='Process text and image'
)

dataset_dict = dataset_dict.filter(
_drop_overlong,
input_columns='input_ids',
fn_kwargs=dict(max_length=self.config.max_length),
num_proc=self.config.num_proc,
desc='Drop overlong'
)

return dataset_dict


def _process_text_and_image(
example: dict[str, Any],
tokenizer: PreTrainedTokenizerBase,
chat_template: str | None,
image_processor: BaseImageProcessor
):
messages = example['messages']
image = example['image']

input_ids = []
labels = []

system_prompt = None
if messages[0]['role'] == 'system':
system_prompt = messages.pop(0)

for i, message in enumerate(messages):
conversation = [message]
if i == 0 and system_prompt is not None:
conversation.insert(0, system_prompt)
text = tokenizer.apply_chat_template(
conversation,
chat_template=chat_template,
tokenize=False,
add_generation_prompt=message['role'] == 'user',
index=i,
length=len(messages)
)
# The same example is tokenized here in several passes; to guarantee the result matches tokenizing the whole text in one pass,
# a token is prepended first and removed again after encoding.
text = tokenizer.bos_token + text
current_input_ids = tokenizer.encode(text, add_special_tokens=False)
current_input_ids = current_input_ids[1:]

if message['role'] in ['system', 'user']:
input_ids += current_input_ids
labels += [-100] * len(current_input_ids)
elif message['role'] == 'assistant':
input_ids += current_input_ids
labels += current_input_ids
else:
raise ValueError(f"Unknown role: `{message['role']}`")

pixel_values = image_processor(image).pixel_values[0]

return {
'input_ids': input_ids,
'labels': labels,
'pixel_values': pixel_values
}


def _drop_overlong(input_ids: list[int], max_length: int):
return len(input_ids) <= max_length
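The per-message tokenization above relies on a small trick: each rendered chunk is encoded with the BOS token prepended and that token is then dropped, so the concatenated ids match what a single pass over the full conversation would produce (SentencePiece-style tokenizers treat the very first piece of a string specially). A minimal sketch of the idea with a generic Hugging Face tokenizer; the checkpoint name is only a placeholder:

from transformers import AutoTokenizer

# Placeholder checkpoint; any tokenizer with a BOS token illustrates the point.
tokenizer = AutoTokenizer.from_pretrained('microsoft/Phi-3.5-mini-instruct')


def encode_chunk(text: str) -> list[int]:
    # Prepend BOS so the chunk is not tokenized as the start of a sequence,
    # then drop the BOS id so only the chunk's own tokens remain.
    ids = tokenizer.encode(tokenizer.bos_token + text, add_special_tokens=False)
    return ids[1:]


# Per-message chunks encoded this way can be concatenated to reproduce the ids
# of tokenizing the whole rendered conversation at once.
prompt_ids = encode_chunk('<|user|>\nDescribe the image.<|end|>\n<|assistant|>\n')
answer_ids = encode_chunk('The image shows a cat on a sofa.<|end|>\n')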
25 changes: 25 additions & 0 deletions src/llm_training/data/visual_instruction_tuning/visual_instruction_tuning_datamodule_config.py
@@ -0,0 +1,25 @@
import logging

from pydantic import field_validator
from transformers import BaseImageProcessor, PreTrainedTokenizerBase

from llm_training.data.chat_templates import get_chat_template
from llm_training.data.hf_based.hf_based_datamodule_config import \
HFBasedDataModuleConfig

logger = logging.getLogger(__name__)


class VisualInstructionTuningDataModuleConfig(HFBasedDataModuleConfig):
tokenizer: PreTrainedTokenizerBase
image_processor: BaseImageProcessor
chat_template: str | None = None
max_length: int | None = None
pad_to_multiple_of: int | None = None

@field_validator('chat_template')
@classmethod
def validate_chat_template(cls, value: str | None) -> str | None:
if value is not None:
value = get_chat_template(value)
return value