add phi-3-mini eager mode example (pytorch#4315)
Summary:
This PR adds an example script that runs phi-3-mini inference in eager mode, with or without the KV cache enabled. It comes in handy when we want to verify that phi-3-mini works in eager mode.

Pull Request resolved: pytorch#4315

Test Plan:
```
python3 -m examples.models.phi-3-mini.eager -s 128 -kv -p "Tell me a story"
python3 -m examples.models.phi-3-mini.eager -s 128 -p "Tell me a story"
```
Verify that the model runs faster with the KV cache enabled.
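
For context on why the two runs differ: without the cache, every decoding step re-runs the model over the entire sequence generated so far, while with the cache each step feeds only the newest token plus the key/value tensors returned by the previous call. Below is a minimal sketch of the two code paths, using the same Hugging Face `Phi3ForCausalLM` API as the script in this diff; variable names are illustrative and each path performs just a single decode step.

```python
import torch
from transformers import AutoTokenizer, Phi3ForCausalLM

model_name = "microsoft/Phi-3-mini-4k-instruct"
model = Phi3ForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer.encode("Tell me a story", return_tensors="pt")

# Without KV cache: each step reprocesses the whole sequence generated so far.
next_token = torch.argmax(model(input_ids=input_ids).logits[:, -1, :], dim=-1)
input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=-1)

# With KV cache: prefill the prompt once, then feed only the newest token
# together with the cached key/values from the previous call.
out = model(input_ids=input_ids, use_cache=True, return_dict=True)
next_token = torch.argmax(out.logits[:, -1, :], dim=-1)
out = model(
    input_ids=next_token.unsqueeze(0),
    use_cache=True,
    return_dict=True,
    past_key_values=out.past_key_values,
)
```

With the cache, the per-step cost stays roughly constant instead of growing with the sequence length, so the `-kv` run should report a noticeably smaller "Time spent".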

Reviewed By: JacobSzwejbka

Differential Revision: D60061822

Pulled By: helunwencser

fbshipit-source-id: 483d2f9e56f9397f78dec805a0c1a110cb1cfc28
helunwencser authored and facebook-github-bot committed Jul 25, 2024
1 parent 5b0700b commit dbf87b0
Showing 1 changed file with 122 additions and 0 deletions.
examples/models/phi-3-mini/eager.py (+122, -0)
@@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# Script to run phi-3-mini model in eager mode.

import argparse
import time

import torch

from transformers import AutoTokenizer, Phi3ForCausalLM

# End-of-text token id for phi-3-mini; generation stops once the model emits it.
end_of_text_token = 32000


def _generate_token(args, model, prompt_tokens):
    current_token = 0
    generated_tokens = []

    print("Generating tokens:", end="", flush=True)

    # Greedy decoding without a KV cache: every step re-runs the full sequence.
    while current_token != end_of_text_token and len(generated_tokens) < args.seq_len:
        outputs = model.forward(input_ids=prompt_tokens)
        current_token = torch.argmax(outputs.logits[:, -1, :], dim=-1).item()
        print(f" {current_token}", end="", flush=True)
        generated_tokens.append(current_token)
        prompt_tokens = torch.cat(
            [prompt_tokens, torch.tensor([[current_token]], dtype=torch.long)], dim=-1
        )

    print("", flush=True)

    return generated_tokens


def _generate_token_with_kv_cache(args, model, prompt_tokens):
    print("Generating tokens:", end="", flush=True)

    # Prefill: run the whole prompt once and keep the returned KV cache.
    result = model.forward(input_ids=prompt_tokens, use_cache=True, return_dict=True)

    current_token = torch.argmax(result.logits[:, -1, :], dim=-1).item()
    current_key_value = result.past_key_values

    print(f" {current_token}", end="", flush=True)

    generated_tokens = [current_token]

    # Decode: feed only the newest token together with the cached key/values.
    while current_token != end_of_text_token and len(generated_tokens) < args.seq_len:
        result = model.forward(
            input_ids=torch.tensor([[current_token]], dtype=torch.long),
            use_cache=True,
            return_dict=True,
            past_key_values=current_key_value,
        )
        current_token = torch.argmax(result.logits[:, -1, :], dim=-1).item()
        current_key_value = result.past_key_values
        print(f" {current_token}", end="", flush=True)
        generated_tokens.append(current_token)

    print("", flush=True)

    return generated_tokens


def main(args):
    seed = 42
    torch.manual_seed(seed)
    model_name = "microsoft/Phi-3-mini-4k-instruct"
    model = Phi3ForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    tokens = tokenizer.encode(args.prompt, return_tensors="pt")

    start = time.time()
    generated_tokens = (
        _generate_token_with_kv_cache(args, model, tokens)
        if args.use_kv_cache
        else _generate_token(args, model, tokens)
    )
    end = time.time()

    print(
        "Generated response: \n {}".format(
            tokenizer.decode(
                generated_tokens,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
        ),
        flush=True,
    )
    print(f"Time spent: {end - start}", flush=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-s",
        "--seq_len",
        type=int,
        default=128,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "-kv",
        "--use_kv_cache",
        default=False,
        action="store_true",
        help="Whether or not to use KV cache",
    )
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        default="Tell me a story",
        help="Prompt as input for the model",
    )
    main(parser.parse_args())
