Skip to content

Commit

Permalink
added generate to rag
Browse files Browse the repository at this point in the history
  • Loading branch information
mihir86 committed Mar 28, 2024
1 parent 20c395e commit 60ab96f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
16 changes: 13 additions & 3 deletions examples/rag/rag_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"5it [00:00, 6295.86it/s]\n"
"5it [00:00, 37786.52it/s]\n"
]
},
{
Expand All @@ -101,7 +101,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:02<00:00, 2.04s/it]"
"100%|██████████| 1/1 [00:01<00:00, 1.83s/it]"
]
},
{
Expand Down Expand Up @@ -147,7 +147,17 @@
},
"outputs": [],
"source": [
"ragLM = RAGLanguageModel(base=language_model, datastore=datastore)"
"rag_LM = RAGLanguageModel(base=language_model, datastore=datastore)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "456ecd6c",
"metadata": {},
"outputs": [],
"source": [
"rag_LM.generate(condition=None)"
]
},
{
Expand Down
18 changes: 18 additions & 0 deletions llments/lm/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,23 @@ def __init__(self, base: LanguageModel, datastore: Datastore):
LanguageModel: The enhanced language model.
"""

def generate(self, condition: str | None, do_sample: bool = False, max_length: int | None = None, temperature: float = 1, num_return_sequences: int = 1) -> list[str]:
"""Generate an output given the language model.
Args:
condition: The conditioning sequence for the output.
If None, the output is not conditioned.
do_sample: Whether to use sampling or greedy decoding.
max_length: The maximum length of the output sequence,
(defaults to model max).
temperature: The value used to module the next token probabilities.
num_return_sequences: The number of independently computed returned
sequences for each element in the batch.
Returns:
str: Sampled output sequences from the language model.
"""
pass



0 comments on commit 60ab96f

Please sign in to comment.