From 996fc4867239ae1303ff2a97e177319018844c47 Mon Sep 17 00:00:00 2001
From: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Date: Tue, 20 Feb 2024 15:56:13 -0800
Subject: [PATCH] Update our sampler documentation to reflect usage (#1444)

We will update our samplers in the near future to push the
backend-specific compilation details out:
https://github.com/keras-team/keras-nlp/pull/1425

Also, in general, we want our documentation to reflect the main usage of
our classes, which is using them with the Seq2SeqLM and CausalLM
classes.

With that in mind, this change updates our sampler docs to show the
practical usage of the sampler classes with our modeling classes.

For the base class, we show the main use case of overriding the
`get_next_token()` method.
---
 keras_nlp/samplers/beam_sampler.py        | 59 +++--------------
 keras_nlp/samplers/contrastive_sampler.py | 35 +++--------
 keras_nlp/samplers/greedy_sampler.py      | 34 +++-------
 keras_nlp/samplers/random_sampler.py      | 27 +++-----
 keras_nlp/samplers/sampler.py             | 76 +++++++----------
 keras_nlp/samplers/top_k_sampler.py       | 27 +++-----
 keras_nlp/samplers/top_p_sampler.py       | 27 +++-----
 7 files changed, 77 insertions(+), 208 deletions(-)

diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py
index 9562f95d14..87948439a8 100644
--- a/keras_nlp/samplers/beam_sampler.py
+++ b/keras_nlp/samplers/beam_sampler.py
@@ -18,11 +18,8 @@
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import ops
 from keras_nlp.samplers.sampler import Sampler
-from keras_nlp.samplers.sampler import call_args_docstring
-from keras_nlp.utils.python_utils import format_docstring
 
 
-@format_docstring(call_args=call_args_docstring)
 @keras_nlp_export("keras_nlp.samplers.BeamSampler")
 class BeamSampler(Sampler):
     """Beam Sampler class.
@@ -42,55 +39,17 @@ class BeamSampler(Sampler):
         {{call_args}}
 
     Examples:
-    Return only the beam with the highest accumulated probability.
     ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
-
-    def next(prompt, cache, index):
-        prompt_batch_size = tf.shape(prompt)[0]
-        hidden_states = np.ones((prompt_batch_size, 10))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((prompt_batch_size, vocab_size))
-        return logits, hidden_states, cache
-
-    output = keras_nlp.samplers.BeamSampler()(
-        next=next,
-        prompt=np.full((batch_size, length), char_lookup["z"], dtype="int32"),
-        index=5,
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> ['zzzzzeeeeeee']
-    ```
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
 
-    Return all beams and their probabilities.
-    ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 8, len(int_lookup)
-
-    def next(prompt, cache, index):
-        prompt_batch_size = tf.shape(prompt)[0]
-        hidden_states = np.ones((prompt_batch_size, 10))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((batch_size, vocab_size))
-        return logits, hidden_states, cache
-
-    beams, probs = keras_nlp.samplers.BeamSampler(return_all_beams=True)(
-        next=next,
-        prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"),
-        index=5,
-    )
-
-    print(beams.shape)
-    # >>> (1, 5, 8)
-    print(probs.shape)
-    # >>> (1, 5)
-    print(["".join([int_lookup[i] for i in s]) for s in beams[0].numpy()])
-    # >>> ['zzzzzeee', 'zzzzzeed', 'zzzzzeec', 'zzzzzeea', 'zzzzzeeb']
+
+    # Pass by name to compile.
+    causal_lm.compile(sampler="beam")
+    causal_lm.generate(["Keras is a"])
+
+    # Pass by object to compile.
+    sampler = keras_nlp.samplers.BeamSampler(num_beams=5)
+    causal_lm.compile(sampler=sampler)
+    causal_lm.generate(["Keras is a"])
     ```
     """
 
diff --git a/keras_nlp/samplers/contrastive_sampler.py b/keras_nlp/samplers/contrastive_sampler.py
index bac65bcfbe..8b3d52d9a5 100644
--- a/keras_nlp/samplers/contrastive_sampler.py
+++ b/keras_nlp/samplers/contrastive_sampler.py
@@ -17,11 +17,8 @@
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import ops
 from keras_nlp.samplers.sampler import Sampler
-from keras_nlp.samplers.sampler import call_args_docstring
-from keras_nlp.utils.python_utils import format_docstring
 
 
-@format_docstring(call_args=call_args_docstring)
 @keras_nlp_export("keras_nlp.samplers.ContrastiveSampler")
 class ContrastiveSampler(Sampler):
     """Contrastive Sampler class.
@@ -44,28 +41,16 @@ class ContrastiveSampler(Sampler):
 
     Examples:
    ```python
-    # Use a simple alphabet of lowercase characters to [0, 26).
-    int_lookup = {i: chr(i + ord("a")) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
-    hidden_size = 5
-    index = 5
-
-    def next(prompt, cache, index):
-        prompt_batch_size = tf.shape(prompt)[0]
-        hidden_states = np.ones((prompt_batch_size, hidden_size))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((prompt_batch_size, vocab_size))
-        return logits, hidden_states, cache
-
-    output = keras_nlp.samplers.ContrastiveSampler()(
-        next=next,
-        prompt=np.full((batch_size, length), char_lookup["z"], dtype="int32"),
-        index=index,
-        hidden_states=np.ones([batch_size, index, hidden_size]),
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> "zzzzzeeeeeee"
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
+
+    # Pass by name to compile.
+    causal_lm.compile(sampler="contrastive")
+    causal_lm.generate(["Keras is a"])
+
+    # Pass by object to compile.
+    sampler = keras_nlp.samplers.ContrastiveSampler(k=5)
+    causal_lm.compile(sampler=sampler)
+    causal_lm.generate(["Keras is a"])
     ```
     """
 
diff --git a/keras_nlp/samplers/greedy_sampler.py b/keras_nlp/samplers/greedy_sampler.py
index 8e178b7468..ee8a6ecc2d 100644
--- a/keras_nlp/samplers/greedy_sampler.py
+++ b/keras_nlp/samplers/greedy_sampler.py
@@ -15,11 +15,8 @@
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import ops
 from keras_nlp.samplers.sampler import Sampler
-from keras_nlp.samplers.sampler import call_args_docstring
-from keras_nlp.utils.python_utils import format_docstring
 
 
-@format_docstring(call_args=call_args_docstring)
 @keras_nlp_export("keras_nlp.samplers.GreedySampler")
 class GreedySampler(Sampler):
     """Greedy sampler class.
@@ -27,29 +24,18 @@ class GreedySampler(Sampler):
     This sampler is implemented on greedy search, i.e., always picking up the
     token of the largest probability as the next token.
 
-    Call arguments:
-        {{call_args}}
-
     Examples:
     ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
-
-    def next(prompt, cache, index):
-        hidden_states = np.ones((batch_size, 10))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((batch_size, vocab_size))
-        return logits, hidden_states, cache
-
-    output = keras_nlp.samplers.GreedySampler()(
-        next=next,
-        prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"),
-        index=5,
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> ['zzzzzaaaaaaa']
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
+
+    # Pass by name to compile.
+    causal_lm.compile(sampler="greedy")
+    causal_lm.generate(["Keras is a"])
+
+    # Pass by object to compile.
+    sampler = keras_nlp.samplers.GreedySampler()
+    causal_lm.compile(sampler=sampler)
+    causal_lm.generate(["Keras is a"])
     ```
     """
 
diff --git a/keras_nlp/samplers/random_sampler.py b/keras_nlp/samplers/random_sampler.py
index b922d29b2a..1ff39c9f9b 100644
--- a/keras_nlp/samplers/random_sampler.py
+++ b/keras_nlp/samplers/random_sampler.py
@@ -16,11 +16,8 @@
 from keras_nlp.backend import ops
 from keras_nlp.backend import random
 from keras_nlp.samplers.sampler import Sampler
-from keras_nlp.samplers.sampler import call_args_docstring
-from keras_nlp.utils.python_utils import format_docstring
 
 
-@format_docstring(call_args=call_args_docstring)
 @keras_nlp_export("keras_nlp.samplers.RandomSampler")
 class RandomSampler(Sampler):
     """Random Sampler class.
@@ -37,24 +34,16 @@ class RandomSampler(Sampler):
 
     Examples:
     ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
 
-    def next(prompt, state, index):
-        hidden_states = np.ones((batch_size, 10))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((batch_size, vocab_size))
-        return logits, hidden_states, state
+    # Pass by name to compile.
+    causal_lm.compile(sampler="random")
+    causal_lm.generate(["Keras is a"])
 
-    output = keras_nlp.samplers.RandomSampler()(
-        next=next,
-        prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"),
-        index=5,
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> ['zzzzzcpnjqij']
+    # Pass by object to compile.
+    sampler = keras_nlp.samplers.RandomSampler(temperature=0.7)
+    causal_lm.compile(sampler=sampler)
+    causal_lm.generate(["Keras is a"])
     ```
     """
 
diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py
index e28fbe9d6e..2101c9277d 100644
--- a/keras_nlp/samplers/sampler.py
+++ b/keras_nlp/samplers/sampler.py
@@ -17,33 +17,8 @@
 from keras_nlp.backend import keras
 from keras_nlp.backend import ops
 from keras_nlp.backend import random
-from keras_nlp.utils.python_utils import format_docstring
-
-call_args_docstring = """next: A function which takes in the
-    `prompt, cache, index` of the current generation loop, and outputs
-    a tuple `(logits, hidden_states, cache)` with `logits` being the
-    logits of next token, `hidden_states` being the representation of
-    the next token, and `cache` for next iteration.
-    prompt: A 2D integer tensor with shape `(batch_size, max_length)`. This
-        tensor will be iteratively updated column by column with new sampled
-        values, starting at `index`.
-    cache: Optional. A tensor or nested structure of tensors that will be
-        updated by each call to `next`. This can be used to cache
-        computations from early iterations of the generative loop.
-    index: Optional. The first index of `prompt` to start sampling at.
-        Usually this is set as the length of the shortest non-padded
-        sequence in `prompt`.
-    mask: Optional. A 2D integer tensor with the same shape as `prompt`.
-        Locations which are `True` in the mask are never updated during
-        sampling. Usually used to mark all locations in the dense prompt
-        tensor which were present in a user input.
-    end_token_id: Optional. The token marking the end of the sequence. If
-        specified, sampling will stop as soon as all sequences in the prompt
-        produce a `end_token_id` in a location where `mask` is `False`.
-"""
-
-
-@format_docstring(call_args=call_args_docstring)
+
+
 @keras_nlp_export("keras_nlp.samplers.Sampler")
 class Sampler:
     """Base sampler class.
@@ -57,35 +32,32 @@ class Sampler:
         {{call_args}}
 
     This base class can be extended to implement different auto-regressive
-    sampling methods. Subclasses can either:
-
-    - Override the `get_next_token()` method, which computes the next token
-      based on a probability distribution over all possible vocab entries.
-    - Override `__call__`, if the sampling method needs additional information
-      beyond the next tokens probability distribution to sample a sequence.
-
-    Please check available subclass samplers for examples.
+    sampling methods. To do so, override the `get_next_token()` method, which
+    computes the next token based on a probability distribution over all
+    possible vocab entries.
 
     Examples:
     ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
-
-    def next(prompt, cache, index):
-        # return a uniform distribution over our alphabet.
-        logits = ops.ones((batch_size, vocab_size))
-        return logits, None, cache
-
-    output = keras_nlp.samplers.GreedySampler()(
-        next=next,
-        prompt=ops.fill((batch_size, length,), char_lookup['z']),
-        index=5,
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> ['zzzzzaaaaaaa']
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
+
+    # Greedy search with some tokens forbidden.
+    class CustomSampler(keras_nlp.samplers.Sampler):
+        def __init__(self, forbidden_tokens, **kwargs):
+            super().__init__(**kwargs)
+            self.forbidden_tokens = forbidden_tokens
+
+        def get_next_token(self, probs):
+            batch_size, vocab_size = keras.ops.shape(probs)
+            for id in self.forbidden_tokens:
+                update = keras.ops.zeros((batch_size, 1))
+                probs = keras.ops.slice_update(probs, (0, id), update)
+            return keras.ops.argmax(probs, axis=-1)
+
+    # 257 = "a" with a leading space, 262 = "the" with a leading space.
+    causal_lm.compile(sampler=CustomSampler(forbidden_tokens=[257, 262]))
+    causal_lm.summary()
+    causal_lm.generate(["That's strange"])
     ```
     """
 
diff --git a/keras_nlp/samplers/top_k_sampler.py b/keras_nlp/samplers/top_k_sampler.py
index 3456694848..513dd738c7 100644
--- a/keras_nlp/samplers/top_k_sampler.py
+++ b/keras_nlp/samplers/top_k_sampler.py
@@ -16,11 +16,8 @@
 from keras_nlp.backend import ops
 from keras_nlp.backend import random
 from keras_nlp.samplers.sampler import Sampler
-from keras_nlp.samplers.sampler import call_args_docstring
-from keras_nlp.utils.python_utils import format_docstring
 
 
-@format_docstring(call_args=call_args_docstring)
 @keras_nlp_export("keras_nlp.samplers.TopKSampler")
 class TopKSampler(Sampler):
     """Top-K Sampler class.
@@ -38,24 +35,16 @@ class TopKSampler(Sampler):
 
     Examples:
     ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
 
-    def next(prompt, cache, index):
-        hidden_states = np.ones((batch_size, 10))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((batch_size, vocab_size))
-        return logits, hidden_states, cache
+    # Pass by name to compile.
+    causal_lm.compile(sampler="top_k")
+    causal_lm.generate(["Keras is a"])
 
-    output = keras_nlp.samplers.TopKSampler(k=3)(
-        next=next,
-        prompt=np.full((batch_size, length,), char_lookup['z'], dtypes="int32"),
-        index=5,
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> ['zzzzzacbbcaa']
+    # Pass by object to compile.
+    sampler = keras_nlp.samplers.TopKSampler(k=5, temperature=0.7)
+    causal_lm.compile(sampler=sampler)
+    causal_lm.generate(["Keras is a"])
     ```
     """
 
diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py
index a04b39aa2b..326f5797a6 100644
--- a/keras_nlp/samplers/top_p_sampler.py
+++ b/keras_nlp/samplers/top_p_sampler.py
@@ -16,11 +16,8 @@
 from keras_nlp.backend import ops
 from keras_nlp.backend import random
 from keras_nlp.samplers.sampler import Sampler
-from keras_nlp.samplers.sampler import call_args_docstring
-from keras_nlp.utils.python_utils import format_docstring
 
 
-@format_docstring(call_args=call_args_docstring)
 @keras_nlp_export("keras_nlp.samplers.TopPSampler")
 class TopPSampler(Sampler):
     """Top-P Sampler class.
@@ -46,24 +43,16 @@ class TopPSampler(Sampler):
 
     Examples:
     ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
 
-    def next(prompt, cache, index):
-        hidden_states = np.ones((batch_size, 10))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((batch_size, vocab_size))
-        return logits, hidden_states, cache
+    # Pass by name to compile.
+    causal_lm.compile(sampler="top_p")
+    causal_lm.generate(["Keras is a"])
 
-    output = keras_nlp.samplers.TopPSampler(p=0.1)(
-        next=next,
-        prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"),
-        index=5,
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> ['zzzzzbabcccb']
+    # Pass by object to compile.
+    sampler = keras_nlp.samplers.TopPSampler(p=0.1, k=1_000)
+    causal_lm.compile(sampler=sampler)
+    causal_lm.generate(["Keras is a"])
     ```
     """
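
For reference, the token-suppression logic in the new `Sampler` docstring example can be sanity-checked standalone, without downloading any preset weights. The following is a minimal sketch, assuming the Keras 3 `keras.ops` API used in the docstring example; the four-token probability row and the forbidden id are invented for illustration:

```python
import keras

# Invented toy distribution: batch of 1, vocab of 4. Token 1 has the
# highest probability, but we will forbid it.
probs = keras.ops.convert_to_tensor([[0.1, 0.5, 0.3, 0.1]])
forbidden_tokens = [1]

# Same steps as CustomSampler.get_next_token() in the docstring above:
# zero out each forbidden token's column, then take the argmax.
batch_size, vocab_size = keras.ops.shape(probs)
for token_id in forbidden_tokens:
    update = keras.ops.zeros((batch_size, 1))
    probs = keras.ops.slice_update(probs, (0, token_id), update)
next_token = keras.ops.argmax(probs, axis=-1)

print(keras.ops.convert_to_numpy(next_token))  # [2] -- forbidden token 1 is skipped
```

Once a model is compiled with `CustomSampler`, every decoding step of `generate()` is routed through this same `get_next_token()` hook.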