diff --git a/examples/medusa/README.md b/examples/medusa/README.md index eb0bef11a..71f4e1576 100644 --- a/examples/medusa/README.md +++ b/examples/medusa/README.md @@ -1,3 +1,5 @@ +# Liger-Kernel Example with Medusa + Medusa is a simple framework that democratizes the acceleration techniques for LLM generation with multiple decoding heads. [[repo](https://arxiv.org/abs/2401.10774)], [[paper](https://arxiv.org/abs/2401.10774)] During training, Medusa requires adding \(k\) decoding heads to the hidden states right before the regular LM head \(h_t\). The \(k\)-th head is used to predict the token in the \((t + k + 1)\)-th position of the next tokens (the original language model head is used to predict the \((t + 1)\)-th position). @@ -16,6 +18,14 @@ pip install -r requirements.txt sh scripts/llama3_8b_medusa.sh ``` +**Notes** +1. This example uses an optional `use_liger` flag. If true, it does a monkey patch to apply liger kernel with medusa heads. +2. The example uses Llama3 model that requires community license agreement and HuggingFace Hub login. If you want to use Llama3 in this example, please make sure you have done the followings: + * Agree on the community license agreement https://huggingface.co/meta-llama/Meta-Llama-3-8B + * Run `huggingface-cli login` and enter your HuggingFace token +3. The default hyperparameters and configurations work on single node with 8xA100 GPUs. For running on device with less GPU RAM, please consider reducing the per-GPU batch size and/or enable `CPUOffload` in FSDP. + + # Memory Profiling Result > **Note:** diff --git a/examples/medusa/scripts/llama3_8b_medusa.sh b/examples/medusa/scripts/llama3_8b_medusa.sh index faa9e96e9..57c3f9dcd 100644 --- a/examples/medusa/scripts/llama3_8b_medusa.sh +++ b/examples/medusa/scripts/llama3_8b_medusa.sh @@ -6,8 +6,7 @@ export NUM_NODES=$WORLD_SIZE export WORLD_SIZE=$((GPUS_PER_NODE * NUM_NODES)) echo "Starting training... Num nodes: $NUM_NODES, Num workers: $WORLD_SIZE" -export OUTPUT_DIR="/shared/user/Meta-Llama-3-70B-Instruct-code-act-3ep" -export DATA_PATH="/shared/public/data/jaszhu/medusa/ShareGPT_V4.3_unfiltered_cleaned_split.json" +export OUTPUT_DIR="./llama3-8b-medusa-liger" export LOCAL_TRAIN_BATCH_SIZE=4 export GRADIENT_ACCUMULATION_STEPS=1 @@ -27,8 +26,6 @@ accelerate launch --config_file fsdp/acc-fsdp.conf \ --main_process_port $MASTER_PORT \ --machine_rank $RANK \ train.py \ - --model_name_or_path /shared/public/models/Meta-Llama-3-8B-Instruct \ - --data_path $DATA_PATH \ --bf16 True \ --output_dir $OUTPUT_DIR \ --num_train_epochs 10 \ @@ -56,4 +53,4 @@ accelerate launch --config_file fsdp/acc-fsdp.conf \ --medusa_lr_multiplier $MEDUSA_LR_MULTIPLIER \ --medusa_only_heads False \ --medusa_return True \ - --with_liger True \ No newline at end of file + --use_liger True \ No newline at end of file diff --git a/examples/medusa/train.py b/examples/medusa/train.py index 3d9be2b1e..4eebf27d9 100644 --- a/examples/medusa/train.py +++ b/examples/medusa/train.py @@ -38,14 +38,14 @@ @dataclass class ModelArguments: model_name_or_path: Optional[str] = field( - default="/shared/public/models/Meta-Llama-3-8B-Instruct" + default="meta-llama/Meta-Llama-3-8B" ) @dataclass class DataArguments: data_path: str = field( - default="sharegpt_clean.json", + default="Aeala/ShareGPT_Vicuna_unfiltered", metadata={"help": "Path to the training data."}, ) eval_data_path: str = field( @@ -99,7 +99,7 @@ class TrainingArguments(transformers.TrainingArguments): "help": "If train medusa heads only, default is False, the whole model will be trained" }, ) - with_liger: bool = field( + use_liger: bool = field( default=False, metadata={"help": "If apply liger kernel to the model."}, ) @@ -331,7 +331,7 @@ def train(): torch_dtype=torch.bfloat16, ) - if training_args.with_liger is True: + if training_args.use_liger is True: apply_liger_kernel_to_llama() # Freeze the base model @@ -344,7 +344,7 @@ def train(): training_args.medusa_num_layers, training_args.medusa_return, training_args.medusa_only_heads, - training_args.with_liger, + training_args.use_liger, ) # Format output dir training_args.output_dir = f"{training_args.output_dir}_medusa_mlp_{model_args.model_name_or_path.split('/')[-1]}_medusa_{training_args.medusa_num_heads}_lr_{training_args.learning_rate}_layers_{training_args.medusa_num_layers}"