This repository contains the code for training llm-jp/llm-jp-3-vila-14b, modified from the VILA repository.
Python version: 3.10.12
python3 -m venv venv
source venv/bin/activate
pip install --upgrade pip
wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.4.2/flash_attn-2.4.2+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install flash_attn-2.4.2+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install -e .
pip install -e ".[train]"
pip install git+https://github.com/huggingface/[email protected]
cp -rv ./llava/train/transformers_replace/* ./venv/lib/python3.10/site-packages/transformers/
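As a quick sanity check (an optional step, not part of the official instructions), you can confirm that the pinned packages ended up in the virtual environment with the expected versions:

python -c "import torch, flash_attn, transformers; print(torch.__version__, flash_attn.__version__, transformers.__version__)"
# expected (roughly): torch 2.0.x, flash_attn 2.4.2, transformers 4.36.2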
See data_prepare/README.md to prepare the training datasets for each step.
There are three stages in the model training; a sketch of example launch commands follows the three steps below.
Tune the parameters of the projector using English and Japanese image-text pair datasets. This takes about 14-15 hours on 8xA100 (40GB).
script: scripts/mdx/release/1_train_step0.sh
Perform multimodal continual pre-training using a relatively large-scale dataset. This takes about 130 hours on 8 nodes of 8xA100 (40GB).
script: scripts/mdx/release/2_train_step1.sh
Fine-tune the model with multimodal instruction data in both English and Japanese. This takes about 11 hours on 4 nodes of 8xA100 (40GB).
script: scripts/mdx/release/3_train_step2.sh
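As a minimal sketch of how the stages chain together, the three scripts can be invoked in order once the datasets from data_prepare are ready. The direct bash invocation below is an assumption: stages two and three are multi-node runs, so on a real cluster they would normally be submitted through your job scheduler with the appropriate node allocation.

# stage 1: projector tuning (single node, 8xA100)
bash scripts/mdx/release/1_train_step0.sh
# stage 2: multimodal continual pre-training (multi-node)
bash scripts/mdx/release/2_train_step1.sh
# stage 3: multimodal instruction tuning (multi-node)
bash scripts/mdx/release/3_train_step2.sh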
We used llm-jp-eval-mm for evaluation. Please note that llm-jp-eval-mm is currently in beta.
python -W ignore scripts/mdx/eval/run_inference_ja.py \
--model-path llm-jp/llm-jp-3-vila-14b \
--query "<image>\nこの画像について説明してください。" \
--image-file path/to/image
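As an example of batching this up, the inference script above can be wrapped in a simple shell loop to caption every image in a directory with the same Japanese prompt ("Please describe this image."); the ./images directory and the .jpg extension are placeholders for your own data:

# run the documented inference command once per image (placeholder paths)
for image in ./images/*.jpg; do
  python -W ignore scripts/mdx/eval/run_inference_ja.py \
    --model-path llm-jp/llm-jp-3-vila-14b \
    --query "<image>\nこの画像について説明してください。" \
    --image-file "$image"
done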
The code is released under the Apache License, Version 2.0.
This codebase is built upon the following projects: