
Commit

example: generate
dest1n1s committed Nov 25, 2024
1 parent 43c0054 commit 42c8d3e
Showing 2 changed files with 19 additions and 4 deletions.
examples/generate.py (19 changes: 15 additions & 4 deletions)
@@ -2,20 +2,31 @@
 from transformers import AutoTokenizer

 from xlens.hooked_transformer import HookedTransformer
+from xlens.utilities.functional import functional


 def example_generate():
     model = HookedTransformer.from_pretrained("Qwen/Qwen2-0.5B")
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
     input_ids = tokenizer("Hello, my dog is cute.", return_tensors="np")["input_ids"]

-    @jax.jit
-    def generate(input_ids: jax.Array, eos_token_id: int, top_k: int = 5, top_p: float = 0.95) -> jax.Array:
+    @functional(transform=jax.jit)
+    def generate(
+        model: HookedTransformer, input_ids: jax.Array, eos_token_id: int, top_k: int = 5, top_p: float = 0.95
+    ) -> jax.Array:
         return model.generate(input_ids, eos_token_id, top_k, top_p, rng=jax.random.PRNGKey(42))[0]

-    generated = generate(input_ids, eos_token_id=tokenizer.eos_token_id)
+    generated = generate(model, input_ids, eos_token_id=tokenizer.eos_token_id)
     print(tokenizer.decode(generated[0]))


+def example_generate_no_jit():
+    model = HookedTransformer.from_pretrained("Qwen/Qwen2-0.5B")
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
+    input_ids = tokenizer("Hello, my dog is cute.", return_tensors="np")["input_ids"]
+    generated = model.generate(input_ids, eos_token_id=tokenizer.eos_token_id)
+    print(tokenizer.decode(generated[0]))
+
+
 if __name__ == "__main__":
     example_generate()
+    example_generate_no_jit()
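For context, the change above swaps a bare @jax.jit on a closure that captured model for xlens's functional wrapper with an explicit model parameter. Below is a minimal, self-contained sketch of that closure-free pattern using plain jax.jit, assuming the model is registered as a JAX pytree; TinyModel and forward are hypothetical stand-ins for illustration, not part of xlens.

# Illustrative sketch only: TinyModel and forward are hypothetical stand-ins,
# not the xlens API. The point is the pattern the commit adopts: pass the model
# to the jitted function as a pytree argument instead of capturing it in a
# Python closure, so the transform traces its parameters as inputs.
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class TinyModel:
    def __init__(self, w):
        self.w = w

    def __call__(self, x):
        return x @ self.w

    def tree_flatten(self):
        # children are the array leaves; no static aux data
        return (self.w,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


@jax.jit
def forward(model, x):
    # model arrives as a traced pytree argument, not a captured Python object
    return model(x)


model = TinyModel(jnp.ones((4, 2)))
print(forward(model, jnp.ones((3, 4))).shape)  # (3, 2); later calls reuse the compiled code

One common reason for this pattern: arrays captured through a closure are baked into the traced computation as constants, while explicit pytree arguments remain genuine inputs to the compiled function.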
tests/integration/test_generate.py (4 changes: 4 additions & 0 deletions)
@@ -21,6 +21,10 @@ def generate(
     def generate_with_timeit():
         return generate(model, input_ids, eos_token_id=tokenizer.eos_token_id)

+    print(
+        "No JIT time:",
+        timeit.timeit(lambda: model.generate(input_ids, eos_token_id=tokenizer.eos_token_id), number=5) / 5,
+    )
     print("JIT time:", timeit.timeit(generate_with_timeit, number=1))
     print("JITted time:", timeit.timeit(generate_with_timeit, number=10) / 10)
     generated = generate_with_timeit()
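The added baseline times eager model.generate against the jitted path: number=1 isolates the first jitted call, which includes compilation, while number=10 amortizes the compiled calls. A general JAX timing caveat, sketched below with a toy function that is not part of this test: dispatch is asynchronous, so blocking on the result inside the timed callable keeps timeit from under-reporting.

# Illustrative timing sketch, independent of this test: the first jitted call
# includes tracing and compilation, later calls reuse the compiled executable.
# block_until_ready() makes timeit measure the computation itself, since JAX
# dispatches work asynchronously.
import timeit

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    return (x @ x.T).sum()


x = jnp.ones((512, 512))
first_call = timeit.timeit(lambda: f(x).block_until_ready(), number=1)
steady_state = timeit.timeit(lambda: f(x).block_until_ready(), number=10) / 10
print("first call (with compilation):", first_call)
print("compiled call (amortized):", steady_state)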