
Commit

Replace model / data with public HF path, update readme (#53)
JasonZhu1313 authored Aug 23, 2024
1 parent 95ac0af commit b418557
Showing 3 changed files with 17 additions and 10 deletions.
10 changes: 10 additions & 0 deletions examples/medusa/README.md
@@ -1,3 +1,5 @@
# Liger-Kernel Example with Medusa

Medusa is a simple framework that democratizes the acceleration techniques for LLM generation with multiple decoding heads. [[repo](https://github.com/FasterDecoding/Medusa)], [[paper](https://arxiv.org/abs/2401.10774)]

During training, Medusa requires adding \(k\) decoding heads to the hidden states right before the regular LM head \(h_t\). The \(k\)-th head is used to predict the token at the \((t + k + 1)\)-th position (the original language model head predicts the \((t + 1)\)-th position).
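For intuition, here is a minimal sketch of one such extra decoding head; the residual-MLP structure and the layer sizes are illustrative assumptions, not the reference implementation:

```python
import torch
import torch.nn as nn


class MedusaHead(nn.Module):
    """One extra decoding head: a residual MLP over the last hidden state,
    followed by its own LM head (sizes here are illustrative)."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Residual MLP, then project to vocabulary logits.
        hidden_states = hidden_states + self.act(self.proj(hidden_states))
        return self.lm_head(hidden_states)


# k = 3 heads predict the (t + 2)-th ... (t + 4)-th tokens; the original LM
# head still predicts the (t + 1)-th token from the same hidden state h_t.
hidden, vocab, k = 64, 1000, 3
heads = nn.ModuleList([MedusaHead(hidden, vocab) for _ in range(k)])
h_t = torch.randn(1, 8, hidden)          # (batch, seq_len, hidden_size)
logits = [head(h_t) for head in heads]   # each: (1, 8, vocab)
```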
@@ -16,6 +18,14 @@ pip install -r requirements.txt
sh scripts/llama3_8b_medusa.sh
```

**Notes**
1. This example uses an optional `use_liger` flag. If set to `True`, it monkey patches the model to apply the Liger kernel to Llama with Medusa heads (see the sketch below).
2. The example uses the Llama 3 model, which requires a community license agreement and Hugging Face Hub login. If you want to use Llama 3 in this example, please make sure you have done the following:
   * Agree to the community license agreement at https://huggingface.co/meta-llama/Meta-Llama-3-8B
   * Run `huggingface-cli login` and enter your Hugging Face token
3. The default hyperparameters and configurations work on a single node with 8 x A100 GPUs. To run on devices with less GPU memory, consider reducing the per-GPU batch size and/or enabling `CPUOffload` in FSDP.
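Below is a condensed sketch of what the `use_liger` flag toggles, mirroring the flow in `train.py` (load the model, then monkey patch the Hugging Face Llama modules with Liger's fused Triton kernels). The model id is the example's default; this is an illustration, not the full training script:

```python
import torch
import transformers
from liger_kernel.transformers import apply_liger_kernel_to_llama

# Load the base model first, as train.py does below.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",  # requires license acceptance + `huggingface-cli login`
    torch_dtype=torch.bfloat16,
)

use_liger = True  # corresponds to `--use_liger True` in the launch script
if use_liger:
    # Monkey patch the HF Llama modules with Liger kernels.
    apply_liger_kernel_to_llama()
```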


# Memory Profiling Result

> **Note:**
7 changes: 2 additions & 5 deletions examples/medusa/scripts/llama3_8b_medusa.sh
@@ -6,8 +6,7 @@ export NUM_NODES=$WORLD_SIZE
export WORLD_SIZE=$((GPUS_PER_NODE * NUM_NODES))
echo "Starting training... Num nodes: $NUM_NODES, Num workers: $WORLD_SIZE"

export OUTPUT_DIR="/shared/user/Meta-Llama-3-70B-Instruct-code-act-3ep"
export DATA_PATH="/shared/public/data/jaszhu/medusa/ShareGPT_V4.3_unfiltered_cleaned_split.json"
export OUTPUT_DIR="./llama3-8b-medusa-liger"

export LOCAL_TRAIN_BATCH_SIZE=4
export GRADIENT_ACCUMULATION_STEPS=1
@@ -27,8 +26,6 @@ accelerate launch --config_file fsdp/acc-fsdp.conf \
--main_process_port $MASTER_PORT \
--machine_rank $RANK \
train.py \
-  --model_name_or_path /shared/public/models/Meta-Llama-3-8B-Instruct \
-  --data_path $DATA_PATH \
--bf16 True \
--output_dir $OUTPUT_DIR \
--num_train_epochs 10 \
@@ -56,4 +53,4 @@ accelerate launch --config_file fsdp/acc-fsdp.conf \
--medusa_lr_multiplier $MEDUSA_LR_MULTIPLIER \
--medusa_only_heads False \
--medusa_return True \
-  --with_liger True
+  --use_liger True
10 changes: 5 additions & 5 deletions examples/medusa/train.py
@@ -38,14 +38,14 @@
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(
-        default="/shared/public/models/Meta-Llama-3-8B-Instruct"
+        default="meta-llama/Meta-Llama-3-8B"
    )


@dataclass
class DataArguments:
data_path: str = field(
default="sharegpt_clean.json",
default="Aeala/ShareGPT_Vicuna_unfiltered",
metadata={"help": "Path to the training data."},
)
eval_data_path: str = field(
@@ -99,7 +99,7 @@ class TrainingArguments(transformers.TrainingArguments):
"help": "If train medusa heads only, default is False, the whole model will be trained"
},
)
with_liger: bool = field(
use_liger: bool = field(
default=False,
metadata={"help": "If apply liger kernel to the model."},
)
Expand Down Expand Up @@ -331,7 +331,7 @@ def train():
        torch_dtype=torch.bfloat16,
    )

-    if training_args.with_liger is True:
+    if training_args.use_liger is True:
        apply_liger_kernel_to_llama()

    # Freeze the base model
@@ -344,7 +344,7 @@
        training_args.medusa_num_layers,
        training_args.medusa_return,
        training_args.medusa_only_heads,
-        training_args.with_liger,
+        training_args.use_liger,
    )
    # Format output dir
    training_args.output_dir = f"{training_args.output_dir}_medusa_mlp_{model_args.model_name_or_path.split('/')[-1]}_medusa_{training_args.medusa_num_heads}_lr_{training_args.learning_rate}_layers_{training_args.medusa_num_layers}"
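For context on how the new public defaults behave, here is a hypothetical standalone parse of simplified copies of the two dataclasses above; the real `train.py` argument wiring may differ:

```python
from dataclasses import dataclass, field
from typing import Optional

import transformers


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="meta-llama/Meta-Llama-3-8B")


@dataclass
class DataArguments:
    data_path: str = field(default="Aeala/ShareGPT_Vicuna_unfiltered")


# With no CLI overrides, the arguments now resolve to public Hugging Face Hub
# ids instead of internal /shared/... paths.
parser = transformers.HfArgumentParser((ModelArguments, DataArguments))
model_args, data_args = parser.parse_args_into_dataclasses(args=[])
assert model_args.model_name_or_path == "meta-llama/Meta-Llama-3-8B"
assert data_args.data_path == "Aeala/ShareGPT_Vicuna_unfiltered"
```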
