Merge pull request #83 from Narsil/recipe_for_other_models
Adding recipe for other models (non llama, non vicuna).
ctlllll authored Feb 27, 2024
2 parents 700ff84 + 64f4924 commit 5e98053
Showing 5 changed files with 181 additions and 89 deletions.
33 changes: 28 additions & 5 deletions README.md
@@ -125,7 +125,7 @@ accelerate launch -m axolotl.cli.train examples/medusa/your_config.yml

The data preparation code for self-distillation can be found in [`data_generation` folder](data_generation) of the current repo. For other datasets, you can directly download the data from the corresponding Hugging Face dataset repo.
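For instance, a dataset repo can be pulled programmatically with `huggingface_hub` (a minimal sketch; the repo id shown is the ShareGPT dump used later in this recipe, swap in your own):

```python
# Sketch: download a Hugging Face dataset repo locally (assumes huggingface_hub is installed).
from huggingface_hub import snapshot_download

local_path = snapshot_download(
    repo_id="Aeala/ShareGPT_Vicuna_unfiltered",  # the ShareGPT dump used in this recipe
    repo_type="dataset",
)
print(local_path)
```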

-### Training (legacy)
+### Training on various architectures
*The following instructions are for the initial release of Medusa; they provide a minimal example of how to train a Medusa-1 model. For the updated version, please refer to the previous section.*

For training, please install:
@@ -141,14 +141,36 @@
Remark: If you haven't installed `git-lfs`, please install it before cloning:
```bash
git lfs install
```

#### Adapt the data to the model you want to enable Medusa on

Start by launching an inference server of your choice that runs the model you want to train Medusa heads for.
Let's use [mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) as an example.

For instance, you can use [text-generation-inference](https://github.com/huggingface/text-generation-inference), which you
can also use to serve the model once the Medusa heads are trained.

```bash
model=mistralai/Mistral-7B-Instruct-v0.2
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --max-input-length 4000 --max-total-tokens 4096 --max-batch-prefill-tokens 4000
```
Some of the ShareGPT conversations are fairly long, so make sure the server can handle them. If a conversation does not fit in the context window, the script will simply skip it.
Skipping a few long conversations shouldn't hurt downstream performance much, but more data is always better.
You can tune various settings to [speed up inference](https://huggingface.co/docs/text-generation-inference/index), but the defaults should be good enough in most cases.
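Before generating data, you can optionally check that the server answers on the OpenAI-compatible chat endpoint that `create_data.py` targets by default (a minimal sketch; the URL and the `"tgi"` model name mirror the script's defaults):

```python
# Smoke test against the local TGI server started above.
# Assumes the default endpoint used by create_data.py: http://localhost:8080/v1/chat/completions
import httpx

payload = {"model": "tgi", "messages": [{"role": "user", "content": "Say hello in one word."}]}
response = httpx.post("http://localhost:8080/v1/chat/completions", json=payload, timeout=None)
print(response.json()["choices"][0]["message"]["content"])
```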

```bash
python create_data.py --input-filename ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json --output-filename mistral.json
```
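The resulting `mistral.json` is a list of conversations, each a list of `{"role", "content"}` messages in which the assistant turns were regenerated by the target model (this follows from how `create_data.py`, shown later in this diff, writes its output). A quick way to inspect it:

```python
# Inspect the regenerated data; the structure mirrors what create_data.py dumps.
import json

with open("mistral.json") as f:
    conversations = json.load(f)

print(len(conversations), "conversations")
print(conversations[0][:2])  # first user turn and the model's regenerated reply
```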

#### Train the model
We follow the training setup from [FastChat](https://github.com/lm-sys/FastChat#fine-tuning), but with a much larger learning rate because we freeze the original model and only train the new heads. Here is the training command for the model prepared above on 4 GPUs. Since we are only training the new heads, training does not require much memory, and data parallelism alone is enough. You can modify the script to fit your own setup. For larger models, we use the same setup. You can also pass `--load_in_8bit` or `--load_in_4bit` to load the base model in a quantized format.
```bash
-torchrun --nproc_per_node=4 medusa/train/train.py --model_name_or_path lmsys/vicuna-7b-v1.3 \
-    --data_path ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json \
+torchrun --nproc_per_node=4 medusa/train/train_legacy.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
+    --data_path mistral.json \
     --bf16 True \
     --output_dir test \
-    --num_train_epochs 1 \
+    --num_train_epochs 2 \
     --per_device_train_batch_size 8 \
     --per_device_eval_batch_size 8 \
     --gradient_accumulation_steps 4 \
@@ -163,7 +185,8 @@
     --model_max_length 2048 \
     --lazy_preprocess True \
     --medusa_num_heads 3 \
-    --medusa_num_layers 1
+    --medusa_num_layers 1 \
+    --deepspeed deepspeed.json
```
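To see why this is light on memory, here is a rough back-of-the-envelope count of the trainable parameters, assuming each legacy Medusa head is one hidden-to-hidden residual block followed by an `lm_head`-sized linear (as initialized in `medusa_model_legacy.py` below) and using Mistral-7B's dimensions:

```python
# Rough estimate only; biases and the exact ResBlock layout are ignored.
hidden, vocab, num_heads = 4096, 32000, 3        # Mistral-7B dims, --medusa_num_heads 3
per_head = hidden * hidden + hidden * vocab      # one hidden->hidden block + hidden->vocab projection
print(f"~{num_heads * per_head / 1e6:.0f}M trainable parameters vs ~7B frozen")  # ~444M
```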
### Push to Hugging Face Hub
You can use the following command to push your model to the Hugging Face Hub:
76 changes: 76 additions & 0 deletions create_data.py
@@ -0,0 +1,76 @@
import typer
import json
from transformers import Conversation
from typing_extensions import Annotated
import httpx
import tqdm.asyncio
import asyncio

app = typer.Typer()


client = httpx.AsyncClient(timeout=None)

async def run(conv: Conversation, url: str):
    # Ask the target model for the next assistant reply and append it to the conversation.
    payload = {"model": "tgi", "messages": conv.messages}
    response = await client.post(url, json=payload)
    content = response.json()
    message = content["choices"][0]["message"]
    message.pop("name")
    conv.add_message(message)




def fix_source(source):
    # Convert ShareGPT-style {"from", "value"} turns into chat {"role", "content"} messages.
    if source and source[0]["from"] == "gpt":
        # Skip if GPT is first to talk
        source = source[1:]
    new_source = []
    for item in source:
        role = "assistant" if item["from"] == "gpt" else "user"
        content = item["value"]
        new_source.append({"role": role, "content": content})
    return new_source


async def recreate_conversation(conversation, sem, url):
    # Replay the user turns one by one and let the target model produce its own answers.
    async with sem:
        conv = Conversation()
        try:
            for message in conversation[::2]:
                assert message["role"] == "user"
                conv.add_message(message)
                await run(conv, url)
        except Exception:
            # Conversations the server cannot handle (e.g. too long) are simply skipped.
            pass
        return conv.messages

@app.command()
def main(
    *,
    input_filename: Annotated[str, typer.Option("--input-filename")],
    output_filename: Annotated[str, typer.Option("--output-filename")],
    url: Annotated[str, typer.Option("--url")] = "http://localhost:8080/v1/chat/completions",
    concurrency: Annotated[int, typer.Option("--concurrency")] = 64,
):
    # Limit the number of in-flight requests to the inference server.
    sem = asyncio.Semaphore(concurrency)

    async def _main():
        with open(input_filename, "r") as f:
            input_data = json.loads(f.read())
        conversations = [fix_source(source["conversations"]) for source in input_data]

        futures = []
        for conversation in conversations:
            future = recreate_conversation(conversation, sem, url)
            futures.append(future)

        recreated_conversations = await tqdm.asyncio.tqdm.gather(*futures)

        with open(output_filename, "w") as f:
            json.dump(recreated_conversations, f, indent=4)

    asyncio.run(_main())


if __name__ == "__main__":
    app()
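For clarity, here is a small, hypothetical illustration of what `fix_source` does to a ShareGPT-style conversation (the `example` data is made up; run it with `fix_source` from `create_data.py` above in scope):

```python
# "human"/"gpt" turns become "user"/"assistant" chat messages; a leading "gpt" turn is dropped.
example = [
    {"from": "human", "value": "Hi there"},
    {"from": "gpt", "value": "Hello! How can I help?"},
]
print(fix_source(example))
# [{'role': 'user', 'content': 'Hi there'}, {'role': 'assistant', 'content': 'Hello! How can I help?'}]
```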
24 changes: 24 additions & 0 deletions deepspeed.json
@@ -0,0 +1,24 @@
{
    "bf16": {
        "enabled": "auto"
    },

    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },

    "gradient_accumulation_steps": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
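This config enables ZeRO stage 3, which shards optimizer state, gradients, and parameters across GPUs; the `"auto"` entries are filled in by the Hugging Face `Trainer` from its own arguments when the script is launched with `--deepspeed deepspeed.json`. A minimal sketch of that wiring, using only values visible in the training command above:

```python
# Sketch only: the Trainer resolves the "auto" fields (bf16, batch sizes, bucket sizes)
# from these arguments before handing the config to DeepSpeed. Requires deepspeed installed.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="test",
    bf16=True,
    num_train_epochs=2,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    deepspeed="deepspeed.json",
)
```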
15 changes: 9 additions & 6 deletions medusa/model/medusa_model_legacy.py
@@ -90,8 +90,8 @@ def __init__(
super().__init__()
self.base_model = base_model
self.config = base_model.config
-self.hidden_size = base_model.lm_head.weight.shape[-1]
-self.vocab_size = base_model.lm_head.weight.shape[0]
+self.hidden_size = base_model.config.hidden_size
+self.vocab_size = base_model.config.vocab_size
self.medusa = medusa_num_heads
self.medusa_num_layers = medusa_num_layers
self.base_model_name_or_path = base_model_name_or_path
@@ -110,9 +110,12 @@ def __init__(
# Ensure medusa_head's dtype and device align with the base_model
self.medusa_head.to(self.base_model.dtype).to(self.base_model.device)

-for i in range(medusa_num_heads):
-    # Initialize the weights of each medusa_head using the base model's weights
-    self.medusa_head[i][-1].weight.data[:] = base_model.lm_head.weight.data[:]
+import deepspeed
+# Under ZeRO-3 the base model's parameters are sharded across ranks, so lm_head.weight
+# must be gathered before its values can be copied into the new heads.
+params = [base_model.lm_head.weight]
+with deepspeed.zero.GatheredParameters(params):
+    for i in range(medusa_num_heads):
+        # Initialize the weights of each medusa_head using the base model's weights
+        self.medusa_head[i][-1].weight.data[:] = base_model.lm_head.weight.data[:]

def get_tokenizer(self):
"""Get the tokenizer of the base model.
@@ -189,7 +192,7 @@ def forward(
torch.Tensor: A tensor containing predictions from all Medusa heads.
(Optional) Original predictions from the base model's LM head.
"""
-with torch.inference_mode():
+# Use no_grad rather than inference_mode: tensors created under inference_mode cannot be
+# reused in autograd, which would prevent training the Medusa heads on the frozen outputs.
+with torch.no_grad():
# Pass input through the base model
outputs = self.base_model.model(
input_ids=input_ids,