Updated LLM samples with support for Phi 3.5 mini and Llama 3.1 #646

Open · wants to merge 1 commit into base `master`
72 changes: 35 additions & 37 deletions PyTorch/llm/README.md
@@ -8,30 +8,33 @@ This sample is extracted from [pytorch-labs/gpt-fast](https://github.com/pytorch-labs/gpt-fast)
- [Setup](#setup)
- [Run the App](#run-the-app)
- [App Settings](#app-settings)
- [External Links](#external-links)
- [Model Licenses](#model-licenses)
- [External Links & Model Licenses](#external-links-and-model-licenses)

## Supported Models

The following models are currently supported by this sample:

- [Phi-2](https://huggingface.co/microsoft/phi-2): Small Language Model with 2.7 billion parameters. Best suited for prompts using QA format, chat format, and code format.
- [Phi-3 Mini 4K](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct): Small Language Model with 3.8 billion parameters using a 4k context window. The Instruct version has been fine-tuned to follow instructions and adhere to safety measures.
- [Phi-3.5 Mini](https://huggingface.co/microsoft/Phi-3.5-mini-instruct): A lightweight, state-of-the-art open model with 3.8 billion parameters using a 128k context window. The Instruct version has been fine-tuned to ensure precise instruction adherence and robust safety measures.
- [LLaMA 2](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf): Large Language Model with 7 billion parameters optimized specifically for dialogue use cases.
- [LLaMA 3](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct): Large Language Model with 8 billion parameters. The Llama 3 instruction tuned models are optimized for dialogue use cases.
- [LLaMA 3.1](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct): Large Language Model with 8 billion parameters. The Llama 3.1 instruction tuned models are an improvement over Llama 3 and are optimized for dialogue use cases.
- [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1): Large Language Model with 7 billion parameters. The Mistral-7B-Instruct-v0.1 model is an instruct fine-tuned version of the Mistral-7B-v0.1 generative text model, trained on a variety of publicly available conversation datasets.

>⚠️ **NOTE**: Other variants of these models may work, but they were not tested.

The models have different VRAM requirements; the following table lists the memory requirements for each tested model.

| Model | fp16 | fp32 |
| --------------- | ------| ----- |
| Phi-2 | 6GB | 12GB |
| Phi-3-mini-4k | 8GB | >16GB |
| Llama-2-7b | 14GB | 28GB |
| Meta-Llama-3-8B | >16GB | 32GB |
| Mistral-7B | 15GB | 30GB |
| Model | fp16 | fp32 |
| ----------------- | ------| ----- |
| Phi-2 | 6GB | 12GB |
| Phi-3-mini-4k | 8GB | >16GB |
| Phi-3.5-mini | 8GB | >16GB |
| Llama-2-7b | 14GB | 28GB |
| Meta-Llama-3-8B | >16GB | 32GB |
| Meta-Llama-3.1-8B | >16GB | 32GB |
| Mistral-7B | 15GB | 30GB |

## Setup
Once you've set up `torch-directml` following our [Windows](https://learn.microsoft.com/windows/ai/directml/pytorch-windows) or [WSL 2](https://learn.microsoft.com/windows/ai/directml/pytorch-wsl) guidance, install the following requirements to run the app:
@@ -44,6 +47,7 @@ To use the Llama and Mistral models, you will need to go through an extra step t
1. Visit
- LLaMA 2: [https://huggingface.co/meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
- LLaMA 3: [https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
- LLaMA 3.1: [https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
- Mistral: [https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
2. Follow the steps on the Hugging Face page to obtain access
3. Run `huggingface-cli login`
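
If you prefer a non-interactive login (for example from WSL or a script), the Hugging Face CLI also accepts a `--token` argument; `<YOUR_HF_TOKEN>` below is a placeholder for an access token created under your Hugging Face account settings:

```
> huggingface-cli login --token <YOUR_HF_TOKEN>
```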
@@ -104,7 +108,7 @@ To run the model using `float32` precision, pass `--precision float32` to `app.py`

### Change the model

You can also select another model to run (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/phi-2`, `meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Llama-2-7b-chat-hf`, `mistralai/Mistral-7B-Instruct-v0.1`).
You can also select another model to run (`microsoft/Phi-3.5-mini-instruct`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/phi-2`, `meta-llama/Meta-Llama-3.1-8B-Instruct`, `meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Llama-2-7b-chat-hf`, `mistralai/Mistral-7B-Instruct-v0.1`).
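
For instance, to run the newly added Phi-3.5 model, the invocation would likely mirror the `--hf_model` pattern used for Phi-3 later in this README:

```
> python app.py --hf_model "microsoft/Phi-3.5-mini-instruct"
```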

For example, to run `Mistral-7B-Instruct-v0.1`, use the following command:

@@ -151,16 +155,20 @@ Following is a list of the basic settings supported by `app.py`:
| `--checkpoint_path` | Path to converted PyTorch model checkpoint. | `checkpoints/{hf_model}/model.pth` |
| `--max_context_length` | Max prompt length including the history. If exceeded, history is clipped starting from the first (user, assistant) pair. | `1500` |
| `--disable_history` | Disable the chat history during generation. | Enabled |
| `--max_pos_emb` | Maximum position used to scale the Phi-3.5 position embeddings. | `8192` |
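
For instance, a likely way to exercise the new flag is to pair it with Phi-3.5, the only model that uses it; `app.py` clamps values above 131072 (the 128k context limit):

```
> python app.py --hf_model "microsoft/Phi-3.5-mini-instruct" --max_pos_emb 131072
```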

>⚠️ **NOTE**: The app uses the checkpoint path to determine the correct transformer model to load, so the path must include the Hugging Face model ID. For example:

- `checkpoints/microsoft/phi-2/model.pth`
- `checkpoints/microsoft/Phi-3-mini-4k-instruct/model.pth`
- `checkpoints/microsoft/Phi-3.5-mini-instruct/model.pth`
- `checkpoints/mistralai/Mistral-7B-v0.1/model.pth`
- `checkpoints/mistralai/Mistral-7B-Instruct-v0.1/model.pth`
- `checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth`
- `checkpoints/meta-llama/Meta-Llama-3-8B/model.pth`
- `checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth`
- `checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth`
- `checkpoints/meta-llama/Meta-Llama-3.1-8B-Instruct/model.pth`
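
For example, assuming the Phi-3.5 checkpoint above has already been downloaded and converted, the app can be pointed at it directly:

```
> python app.py --checkpoint_path "checkpoints/microsoft/Phi-3.5-mini-instruct/model.pth"
```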

## _[Optional]_ Prepare the Supported Models
This step is optional, as the `app.py` script in the [Run the App](#run-the-app) section handles both downloading and optimizing a PyTorch model with DirectML.
@@ -179,36 +187,26 @@ After the model is downloaded and converted, you can pass the following parameter
> python app.py --hf_model "microsoft/Phi-3-mini-4k-instruct"
```

### Download a DirectML optimized PyTorch model from the [Microsoft Hugging Face repo](https://huggingface.co/microsoft):

1. `cd checkpoints`
2. `git clone https://huggingface.co/{hf_model} {hf_model}`
3. `cd ../`
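
As a concrete illustration (assuming the DirectML-optimized `microsoft/Phi-3-mini-4k-instruct-pytdml` repository as `{hf_model}`), the steps would be:

```
> cd checkpoints
> git clone https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-pytdml microsoft/Phi-3-mini-4k-instruct-pytdml
> cd ../
```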

After the model is downloaded, you can pass the following parameter to `app.py` to run the language model:

```
> python app.py --checkpoint_path "checkpoints/{hf_model}/model.pth"
```

## External Links
- [Phi-2 Hugging Face Repository](https://huggingface.co/microsoft/phi-2)
- [Phi-3 Hugging Face Repository](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [LLaMA 2 Hugging Face Repository](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
- [LLaMA 3 Hugging Face Repository](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
- [Mistral 7B Hugging Face Repository](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
## External Links and Model Licenses
- [Phi-2 Hugging Face Repository](https://huggingface.co/microsoft/phi-2)
This sample uses the Phi-2 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the [MIT license](https://huggingface.co/microsoft/phi-2/resolve/main/LICENSE). If you comply with the license, you have the rights described therein. By using the Sample, you accept the terms.
- [Phi-3 Hugging Face Repository](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
This sample uses the Phi-3 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the [MIT license](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/LICENSE). If you comply with the license, you have the rights described therein. By using the Sample, you accept the terms.
- [Phi-3.5 Hugging Face Repository](https://huggingface.co/microsoft/Phi-3.5-mini-instruct)
This sample uses the Phi-3.5 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the [MIT license](https://huggingface.co/microsoft/Phi-3.5-mini-instruct/resolve/main/LICENSE). If you comply with the license, you have the rights described therein. By using the Sample, you accept the terms.
- [LLaMA 2 Hugging Face Repository](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
This sample uses the Llama-2 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the [LLAMA 2 COMMUNITY LICENSE AGREEMENT](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/LICENSE.txt). For terms of use, please visit: Llama 2 - [Acceptable Use Policy - Meta AI](https://ai.meta.com/llama/use-policy/). If you comply with the license and terms of use, you have the rights described therein. By using the Sample, you accept the terms.
- [LLaMA 3 Hugging Face Repository](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
This sample uses the Llama-3 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the
[LLAMA 3 COMMUNITY LICENSE AGREEMENT](https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/LICENSE). For terms of use, please visit: [Meta Llama 3 Acceptable Use Policy](https://llama.meta.com/llama3/use-policy/). If you comply with the license and terms of use, you have the rights described therein. By using the Sample, you accept the terms.
- [LLaMA 3.1 Hugging Face Repository](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
This sample uses the Llama-3.1 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the
[LLAMA 3.1 COMMUNITY LICENSE AGREEMENT](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/blob/main/LICENSE). For terms of use, please visit: [Meta Llama 3.1 Acceptable Use Policy](https://llama.meta.com/llama3_1/use-policy/). If you comply with the license and terms of use, you have the rights described therein. By using the Sample, you accept the terms.
- [Mistral 7B Hugging Face Repository](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
This sample uses the Mistral model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the [Apache-2.0 license](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md). If you comply with the license, you have the rights described therein. By using the Sample, you accept the terms.
- [PyTorch gpt-fast Source Code](https://github.com/pytorch-labs/gpt-fast/)

## Model Licenses

- [DirectML-Optimized Phi-2 Hugging Face Repository](https://huggingface.co/microsoft/phi-2-pytdml)
This sample uses the phi-2 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the [MIT license](https://huggingface.co/microsoft/phi-2/resolve/main/LICENSE). If you comply with the license, you have the rights described therein. By using the Sample, you accept the terms.

- [DirectML-Optimized Phi-3 Hugging Face Repository](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-pytdml)
This sample uses the phi-3 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the [MIT license](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/LICENSE). If you comply with the license, you have the rights described therein. By using the Sample, you accept the terms.

- [DirectML-Optimized LLaMA 2 Hugging Face Repository](https://huggingface.co/microsoft/Llama-2-7b-chat-hf-pytdml)
This sample uses the Llama-2 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the[ LLAMA 2 COMMUNITY LICENSE AGREEMENT](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/LICENSE.txt). For terms of use, please visit: Llama 2 - [Acceptable Use Policy - Meta AI](https://ai.meta.com/llama/use-policy/). If you comply with the license and terms of use, you have the rights described therein. By using the Sample, you accept the terms.

- [DirectML-Optimized Mistral 7B Hugging Face Repository](https://huggingface.co/microsoft/Mistral-7B-Instruct-v0.1-pytdml)
This sample uses the Mistral model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the [Apache-2.0 license](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md). If you comply with the license, you have the rights described therein. By using the Sample, you accept the terms.
23 changes: 17 additions & 6 deletions PyTorch/llm/app.py
@@ -125,7 +125,8 @@ def __init__(
precision: str = 'float32',
stream_every_n: int = 7,
max_context_length: int = 3500,
use_history: bool = False
use_history: bool = False,
max_pos_emb: int = 8192
):
self.prompt = prompt
self.interactive = interactive
@@ -139,6 +140,7 @@ def __init__(
self.stream_every_n = stream_every_n
self.max_context_length = max_context_length
self.use_history = use_history
self.max_pos_emb = max_pos_emb

self.tokenizer = None
self.model = None
@@ -177,8 +179,7 @@ def format_prompt_and_encode(
messages.append(assistant)
messages.append({"role": "user", "content": prompt})
tokens = self.tokenizer.apply_chat_template(
messages, return_tensors="pt", add_generation_prompt=self.is_llama_3)[0].to(dtype=torch.int, device=device)

messages, return_tensors="pt", add_generation_prompt=True)[0].to(dtype=torch.int, device=device)
if self.use_history:
while tokens.size(0) > max_context_length:
print("Clipping history of conversation as it exceeds the max context length.")
@@ -188,7 +189,7 @@
else:
break
tokens = self.tokenizer.apply_chat_template(
messages, return_tensors="pt", add_generation_prompt=self.is_llama_3)[0].to(dtype=torch.int, device=device)
messages, return_tensors="pt", add_generation_prompt=True)[0].to(dtype=torch.int, device=device)

return tokens

@@ -274,7 +275,7 @@ def load_model(self) -> None:
if self.is_phi_2:
self.precision = torch.float32

self.model = _load_model(self.checkpoint_path, device, self.precision)
self.model = _load_model(self.checkpoint_path, device, self.precision, max_pos_emb=self.max_pos_emb)
self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint_path.parent)
if self.max_context_length > self.model.config.block_size - (self.max_new_tokens+1):
raise ValueError(
@@ -288,6 +289,7 @@ def chat(
**sampling_kwargs
) -> Iterator[str]:
torch.manual_seed(1235)

encoded = self.encode_tokens(
prompt,
history,
@@ -348,7 +350,15 @@ def chat(message: str, history: List[List[str]]) -> Iterator[str]:
choices=['float16', 'float32'],
help='Precision to run the generation with.'
)
parser.add_argument(
'--max_pos_emb',
type=int,
default=8192,
help='Maximum Position to scale Phi-3.5 position encodings.'
)
args = parser.parse_args()
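# Phi-3.5 supports at most a 128k (131072-token) context window, so larger requests are clamped.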
if args.max_pos_emb > 131072:
args.max_pos_emb = 131072

llm_model = LLM_Model(prompt = "Hello",
interactive = False,
@@ -360,7 +370,8 @@ def chat(message: str, history: List[List[str]]) -> Iterator[str]:
checkpoint_path = args.checkpoint_path,
precision = args.precision,
max_context_length = args.max_context_length,
use_history = not args.disable_history)
use_history = not args.disable_history,
max_pos_emb=args.max_pos_emb)
llm_model.load_model()

demo = gr.ChatInterface(chat).queue()
11 changes: 9 additions & 2 deletions PyTorch/llm/models/base.py
@@ -24,6 +24,9 @@ def __init__(self, config: ModelArgs) -> None:
self.max_batch_size = -1
self.max_seq_length = -1

def set_max_position_embeddings(self, max_pos_emb):
self.max_position_embeddings = max_pos_emb

def setup_caches(self, max_batch_size, max_seq_length):
head_dim = self.config.dim // self.config.n_head
max_seq_length = find_multiple(max_seq_length, 8)
@@ -49,5 +52,9 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
return logits

@classmethod
def from_name(cls, name: str):
return cls(ModelArgs.from_name(name))
def from_name(cls, name: str, max_pos_emb: int = 8192):
model = cls(ModelArgs.from_name(name))
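# Only Phi-3.5 needs its maximum position embeddings overridden; other models are returned unchanged.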
if "phi-3.5" in name.lower():
model.set_max_position_embeddings(max_pos_emb)

return model
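
A minimal usage sketch for the updated factory method; `Transformer` is an assumed class name (following the gpt-fast convention this sample is based on), and the snippet is illustrative rather than taken from the sample itself:

```
# Hypothetical caller: the name comes from the checkpoint folder, and only a
# name containing "phi-3.5" triggers the extended position-embedding scaling.
from models.base import Transformer  # assumed import path

model = Transformer.from_name("Phi-3.5-mini-instruct", max_pos_emb=131072)
print(model.max_position_embeddings)  # 131072; other models never set this attribute
```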