diff --git a/examples/rag/rag_example.ipynb b/examples/rag/rag_example.ipynb
index 7cbf6d6..c07a234 100644
--- a/examples/rag/rag_example.ipynb
+++ b/examples/rag/rag_example.ipynb
@@ -87,7 +87,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "5it [00:00, 6295.86it/s]\n"
+      "5it [00:00, 37786.52it/s]\n"
      ]
     },
     {
@@ -101,7 +101,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "100%|██████████| 1/1 [00:02<00:00,  2.04s/it]"
+      "100%|██████████| 1/1 [00:01<00:00,  1.83s/it]"
      ]
     },
     {
@@ -147,7 +147,17 @@
    },
    "outputs": [],
    "source": [
-    "ragLM = RAGLanguageModel(base=language_model, datastore=datastore)"
+    "rag_LM = RAGLanguageModel(base=language_model, datastore=datastore)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "456ecd6c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "rag_LM.generate(condition=None)"
    ]
   },
   {
diff --git a/llments/lm/rag.py b/llments/lm/rag.py
index a850f90..0a0a85b 100644
--- a/llments/lm/rag.py
+++ b/llments/lm/rag.py
@@ -14,5 +14,23 @@ def __init__(self, base: LanguageModel, datastore: Datastore):
 
         Returns:
             LanguageModel: The enhanced language model.
         """
 
+    def generate(self, condition: str | None, do_sample: bool = False, max_length: int | None = None, temperature: float = 1, num_return_sequences: int = 1) -> list[str]:
+        """Generate an output given the language model.
+
+        Args:
+            condition: The conditioning sequence for the output.
+                If None, the output is not conditioned.
+            do_sample: Whether to use sampling or greedy decoding.
+            max_length: The maximum length of the output sequence
+                (defaults to model max).
+            temperature: The value used to modulate the next token probabilities.
+            num_return_sequences: The number of independently computed returned
+                sequences for each element in the batch.
+
+        Returns:
+            list[str]: Sampled output sequences from the language model.
+        """
+        # Stub: raising is preferable to `pass`, which would silently
+        # return None in violation of the declared `list[str]` return type.
+        raise NotImplementedError("RAG-augmented generation is not implemented yet.")