Update our sampler documentation to reflect usage (keras-team#1444)
We will update our samplers in the near future to push the backend-specific
compilation details out: keras-team#1425

Also, in general, we want our documentation to reflect the main usage of
our classes, which is using them with the Seq2SeqLM and CausalLM classes.

So with that in mind, this updates our sampler docs to show the
practical usage of the sampling classes with our modeling classes. For
the base class, we show the main use case of overriding the
`get_next_token()` function.
mattdangerw authored and abuelnasr0 committed Apr 2, 2024
1 parent ab7b48a commit 996fc48
Showing 7 changed files with 77 additions and 208 deletions.
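In practice, the usage pattern the updated docstrings document is sketched below; `gpt2_base_en` and the prompt text simply mirror the examples added in this diff, and any other `CausalLM` preset would work the same way:

```python
import keras_nlp

# Load a generative model; samplers are attached at compile time.
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Pass a sampler by name when compiling.
causal_lm.compile(sampler="top_k")
causal_lm.generate(["Keras is a"])

# Or pass a configured sampler object for finer control.
sampler = keras_nlp.samplers.TopKSampler(k=5, temperature=0.7)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])
```

Custom strategies subclass `keras_nlp.samplers.Sampler` and override `get_next_token()`, as the new base-class example in `sampler.py` below shows.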
59 changes: 9 additions & 50 deletions keras_nlp/samplers/beam_sampler.py
@@ -18,11 +18,8 @@
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import ops
 from keras_nlp.samplers.sampler import Sampler
-from keras_nlp.samplers.sampler import call_args_docstring
-from keras_nlp.utils.python_utils import format_docstring


-@format_docstring(call_args=call_args_docstring)
 @keras_nlp_export("keras_nlp.samplers.BeamSampler")
 class BeamSampler(Sampler):
     """Beam Sampler class.

@@ -42,55 +39,17 @@ class BeamSampler(Sampler):
         {{call_args}}
     Examples:
-    Return only the beam with the highest accumulated probability.
     ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
-    def next(prompt, cache, index):
-        prompt_batch_size = tf.shape(prompt)[0]
-        hidden_states = np.ones((prompt_batch_size, 10))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((prompt_batch_size, vocab_size))
-        return logits, hidden_states, cache
-    output = keras_nlp.samplers.BeamSampler()(
-        next=next,
-        prompt=np.full((batch_size, length), char_lookup["z"], dtype="int32"),
-        index=5,
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> ['zzzzzeeeeeee']
-    ```
-    Return all beams and their probabilities.
-    ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 8, len(int_lookup)
-    def next(prompt, cache, index):
-        prompt_batch_size = tf.shape(prompt)[0]
-        hidden_states = np.ones((prompt_batch_size, 10))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((batch_size, vocab_size))
-        return logits, hidden_states, cache
-    beams, probs = keras_nlp.samplers.BeamSampler(return_all_beams=True)(
-        next=next,
-        prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"),
-        index=5,
-    )
-    print(beams.shape)
-    # >>> (1, 5, 8)
-    print(probs.shape)
-    # >>> (1, 5)
-    print(["".join([int_lookup[i] for i in s]) for s in beams[0].numpy()])
-    # >>> ['zzzzzeee', 'zzzzzeed', 'zzzzzeec', 'zzzzzeea', 'zzzzzeeb']
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
+    # Pass by name to compile.
+    causal_lm.compile(sampler="beam")
+    causal_lm.generate(["Keras is a"])
+    # Pass by object to compile.
+    sampler = keras_nlp.samplers.BeamSampler(num_beams=5)
+    causal_lm.compile(sampler=sampler)
+    causal_lm.generate(["Keras is a"])
     ```
     """

35 changes: 10 additions & 25 deletions keras_nlp/samplers/contrastive_sampler.py
@@ -17,11 +17,8 @@
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import ops
 from keras_nlp.samplers.sampler import Sampler
-from keras_nlp.samplers.sampler import call_args_docstring
-from keras_nlp.utils.python_utils import format_docstring


-@format_docstring(call_args=call_args_docstring)
 @keras_nlp_export("keras_nlp.samplers.ContrastiveSampler")
 class ContrastiveSampler(Sampler):
     """Contrastive Sampler class.

@@ -44,28 +41,16 @@ class ContrastiveSampler(Sampler):
     Examples:
     ```python
-    # Use a simple alphabet of lowercase characters to [0, 26).
-    int_lookup = {i: chr(i + ord("a")) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
-    hidden_size = 5
-    index = 5
-    def next(prompt, cache, index):
-        prompt_batch_size = tf.shape(prompt)[0]
-        hidden_states = np.ones((prompt_batch_size, hidden_size))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((prompt_batch_size, vocab_size))
-        return logits, hidden_states, cache
-    output = keras_nlp.samplers.ContrastiveSampler()(
-        next=next,
-        prompt=np.full((batch_size, length), char_lookup["z"], dtype="int32"),
-        index=index,
-        hidden_states=np.ones([batch_size, index, hidden_size]),
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> "zzzzzeeeeeee"
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
+    # Pass by name to compile.
+    causal_lm.compile(sampler="contrastive")
+    causal_lm.generate(["Keras is a"])
+    # Pass by object to compile.
+    sampler = keras_nlp.samplers.ContrastiveSampler(k=5)
+    causal_lm.compile(sampler=sampler)
+    causal_lm.generate(["Keras is a"])
     ```
     """

34 changes: 10 additions & 24 deletions keras_nlp/samplers/greedy_sampler.py
@@ -15,41 +15,27 @@
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import ops
 from keras_nlp.samplers.sampler import Sampler
-from keras_nlp.samplers.sampler import call_args_docstring
-from keras_nlp.utils.python_utils import format_docstring


-@format_docstring(call_args=call_args_docstring)
 @keras_nlp_export("keras_nlp.samplers.GreedySampler")
 class GreedySampler(Sampler):
     """Greedy sampler class.
     This sampler is implemented on greedy search, i.e., always picking up the
     token of the largest probability as the next token.
     Call arguments:
         {{call_args}}
     Examples:
     ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
-    def next(prompt, cache, index):
-        hidden_states = np.ones((batch_size, 10))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((batch_size, vocab_size))
-        return logits, hidden_states, cache
-    output = keras_nlp.samplers.GreedySampler()(
-        next=next,
-        prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"),
-        index=5,
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> ['zzzzzaaaaaaa']
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
+    # Pass by name to compile.
+    causal_lm.compile(sampler="greedy")
+    causal_lm.generate(["Keras is a"])
+    # Pass by object to compile.
+    sampler = keras_nlp.samplers.GreedySampler()
+    causal_lm.compile(sampler=sampler)
+    causal_lm.generate(["Keras is a"])
     ```
     """

27 changes: 8 additions & 19 deletions keras_nlp/samplers/random_sampler.py
@@ -16,11 +16,8 @@
 from keras_nlp.backend import ops
 from keras_nlp.backend import random
 from keras_nlp.samplers.sampler import Sampler
-from keras_nlp.samplers.sampler import call_args_docstring
-from keras_nlp.utils.python_utils import format_docstring


-@format_docstring(call_args=call_args_docstring)
 @keras_nlp_export("keras_nlp.samplers.RandomSampler")
 class RandomSampler(Sampler):
     """Random Sampler class.

@@ -37,24 +34,16 @@ class RandomSampler(Sampler):
     Examples:
     ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
-    def next(prompt, state, index):
-        hidden_states = np.ones((batch_size, 10))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((batch_size, vocab_size))
-        return logits, hidden_states, state
-    output = keras_nlp.samplers.RandomSampler()(
-        next=next,
-        prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"),
-        index=5,
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> ['zzzzzcpnjqij']
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
+    # Pass by name to compile.
+    causal_lm.compile(sampler="random")
+    causal_lm.generate(["Keras is a"])
+    # Pass by object to compile.
+    sampler = keras_nlp.samplers.RandomSampler(temperature=0.7)
+    causal_lm.compile(sampler=sampler)
+    causal_lm.generate(["Keras is a"])
     ```
     """

76 changes: 24 additions & 52 deletions keras_nlp/samplers/sampler.py
@@ -17,33 +17,8 @@
 from keras_nlp.backend import keras
 from keras_nlp.backend import ops
 from keras_nlp.backend import random
-from keras_nlp.utils.python_utils import format_docstring

-call_args_docstring = """next: A function which takes in the
-    `prompt, cache, index` of the current generation loop, and outputs
-    a tuple `(logits, hidden_states, cache)` with `logits` being the
-    logits of next token, `hidden_states` being the representation of
-    the next token, and `cache` for next iteration.
-prompt: A 2D integer tensor with shape `(batch_size, max_length)`. This
-    tensor will be iteratively updated column by column with new sampled
-    values, starting at `index`.
-cache: Optional. A tensor or nested structure of tensors that will be
-    updated by each call to `next`. This can be used to cache
-    computations from early iterations of the generative loop.
-index: Optional. The first index of `prompt` to start sampling at.
-    Usually this is set as the length of the shortest non-padded
-    sequence in `prompt`.
-mask: Optional. A 2D integer tensor with the same shape as `prompt`.
-    Locations which are `True` in the mask are never updated during
-    sampling. Usually used to mark all locations in the dense prompt
-    tensor which were present in a user input.
-end_token_id: Optional. The token marking the end of the sequence. If
-    specified, sampling will stop as soon as all sequences in the prompt
-    produce a `end_token_id` in a location where `mask` is `False`.
-"""


-@format_docstring(call_args=call_args_docstring)
 @keras_nlp_export("keras_nlp.samplers.Sampler")
 class Sampler:
     """Base sampler class.

@@ -57,35 +32,32 @@ class Sampler:
         {{call_args}}
     This base class can be extended to implement different auto-regressive
-    sampling methods. Subclasses can either:
-    - Override the `get_next_token()` method, which computes the next token
-      based on a probability distribution over all possible vocab entries.
-    - Override `__call__`, if the sampling method needs additional information
-      beyond the next tokens probability distribution to sample a sequence.
-    Please check available subclass samplers for examples.
+    sampling methods. To do so, override the `get_next_token()` method, which
+    computes the next token based on a probability distribution over all
+    possible vocab entries.
     Examples:
     ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
-    def next(prompt, cache, index):
-        # return a uniform distribution over our alphabet.
-        logits = ops.ones((batch_size, vocab_size))
-        return logits, None, cache
-    output = keras_nlp.samplers.GreedySampler()(
-        next=next,
-        prompt=ops.fill((batch_size, length,), char_lookup['z']),
-        index=5,
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> ['zzzzzaaaaaaa']
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
+    # Greedy search with some tokens forbidden.
+    class CustomSampler(keras_nlp.samplers.Sampler):
+        def __init__(self, forbidden_tokens, **kwargs):
+            super().__init__(**kwargs)
+            self.forbidden_tokens = forbidden_tokens
+        def get_next_token(self, probs):
+            batch_size, vocab_size = keras.ops.shape(probs)
+            for id in self.forbidden_tokens:
+                update = keras.ops.zeros((batch_size, 1))
+                probs = keras.ops.slice_update(probs, (0, id), update)
+            return keras.ops.argmax(probs, axis=-1)
+    # 257 = "a" with a leading space, 262 = "the" with a leading space.
+    causal_lm.compile(sampler=CustomSampler(forbidden_tokens=[257, 262]))
+    causal_lm.summary()
+    causal_lm.generate(["That's strange"])
     ```
     """

27 changes: 8 additions & 19 deletions keras_nlp/samplers/top_k_sampler.py
@@ -16,11 +16,8 @@
 from keras_nlp.backend import ops
 from keras_nlp.backend import random
 from keras_nlp.samplers.sampler import Sampler
-from keras_nlp.samplers.sampler import call_args_docstring
-from keras_nlp.utils.python_utils import format_docstring


-@format_docstring(call_args=call_args_docstring)
 @keras_nlp_export("keras_nlp.samplers.TopKSampler")
 class TopKSampler(Sampler):
     """Top-K Sampler class.

@@ -38,24 +35,16 @@ class TopKSampler(Sampler):
     Examples:
     ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
-    def next(prompt, cache, index):
-        hidden_states = np.ones((batch_size, 10))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((batch_size, vocab_size))
-        return logits, hidden_states, cache
-    output = keras_nlp.samplers.TopKSampler(k=3)(
-        next=next,
-        prompt=np.full((batch_size, length,), char_lookup['z'], dtypes="int32"),
-        index=5,
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> ['zzzzzacbbcaa']
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
+    # Pass by name to compile.
+    causal_lm.compile(sampler="top_k")
+    causal_lm.generate(["Keras is a"])
+    # Pass by object to compile.
+    sampler = keras_nlp.samplers.TopKSampler(k=5, temperature=0.7)
+    causal_lm.compile(sampler=sampler)
+    causal_lm.generate(["Keras is a"])
     ```
     """
