Skip to content

Commit

Permalink
Use choices list instead of grammar
Browse files Browse the repository at this point in the history
  • Loading branch information
jncraton committed Dec 29, 2023
1 parent 9ae2a8a commit 1c3dab1
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 10 deletions.
2 changes: 1 addition & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
### Added

- Support static batching by passing lists to `do`
- Support basic BNF grammar on `do`
- Support choices list on `do` to restrict possible outputs

## 0.12.0 - 2023-12-02

Expand Down
14 changes: 5 additions & 9 deletions languagemodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,11 @@ def do(prompt: str) -> str:
...


def do(prompt, grammar=None):
def do(prompt, choices=None):
"""Follow a single-turn instructional prompt
:param prompt: Instructional prompt(s) to follow
:param choices: If provided, outputs are restricted to values in choices
:return: Completion returned from the language model
Note that this function is overloaded to return a list of results if
Expand All @@ -80,19 +81,14 @@ def do(prompt, grammar=None):
>>> do(["Pick the sport from the list: baseball, texas, chemistry"] * 2)
['Baseball.', 'Baseball.']
>>> do(["Say red", "Say blue"], 'root ::= "red" | "blue"')
>>> do(["Say red", "Say blue"], choices=["red", "blue"])
['red', 'blue']
"""

prompts = [prompt] if isinstance(prompt, str) else prompt

if grammar:
assert grammar.startswith("root ::= ")
grammar = grammar[len("root ::= "):]
grammar = grammar.strip('"')
targets = grammar.split('" | "')

results = [r[0] for r in rank_instruct(prompts, targets)]
if choices:
results = [r[0] for r in rank_instruct(prompts, choices)]
else:
results = generate(prompts, max_tokens=config["max_tokens"], topk=1)

Expand Down

0 comments on commit 1c3dab1

Please sign in to comment.