Support batched inference
jncraton committed Dec 28, 2023
1 parent 79ebb32 commit 9d06ee2
Showing 3 changed files with 47 additions and 25 deletions.
4 changes: 4 additions & 0 deletions changelog.md
@@ -7,6 +7,10 @@
- Improved search speed when searching many documents
- Reduce memory usage for large document embeddings

### Added

- Support static batching by passing lists to `do`

## 0.12.0 - 2023-12-02

### Changed
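As a quick illustration of the changelog entry above, here is a minimal usage sketch of static batching through `do`. It assumes the package is imported as `languagemodels`; the outputs shown in the comments are illustrative only.

```python
import languagemodels as lm

# A single string still returns a single string
print(lm.do("Is the following positive or negative: I love Star Trek."))
# e.g. 'Positive.'

# A list of prompts is run as one batch and returns a list of answers
print(lm.do([
    "Translate Spanish to English: Hola mundo!",
    "Pick the sport from the list: baseball, texas, chemistry",
]))
# e.g. ['Hello world!', 'Baseball.']
```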
44 changes: 31 additions & 13 deletions languagemodels/__init__.py
@@ -2,6 +2,7 @@
import datetime
import json
import re
from typing import overload

from languagemodels.config import config
from languagemodels.inference import (
@@ -34,23 +35,34 @@ def complete(prompt: str) -> str:
"""

result = generate(
"Write a sentence", prefix=prompt,
["Write a sentence"], prefix=prompt,
max_tokens=config["max_tokens"], temperature=0.7, topk=40
)
)[0]

if result.startswith(prompt):
prefix_length = len(prompt)
return result[prefix_length:]
else:
return result

@overload
def do(prompt: list) -> list:
...

@overload
def do(prompt: str) -> str:
...

def do(prompt):
"""Follow a single-turn instructional prompt
:param prompt: Instructional prompt to follow
:param prompt: Instructional prompt(s) to follow
:return: Completion returned from the language model
Note that this function is overloaded to return a list of results if
a list of prompts is provided and a single string if a single
prompt is provided as a string
Examples:
>>> do("Translate Spanish to English: Hola mundo!") #doctest: +SKIP
@@ -61,17 +73,23 @@ def do(prompt: str) -> str:
>>> do("Is the following positive or negative: I love Star Trek.")
'Positive.'
>>> do(["Pick the sport from the list: baseball, texas, chemistry"] * 2)
['Baseball.', 'Baseball.']
"""
result = generate(prompt, max_tokens=config["max_tokens"], topk=1)

if len(result.split()) == 1:
result = result.title()
prompts = [prompt] if isinstance(prompt, str) else prompt

if result[-1] not in (".", "!", "?"):
result = result + "."
results = generate(prompts, max_tokens=config["max_tokens"], topk=1)

return result
for i, result in enumerate(results):
if len(result.split()) == 1:
results[i] = result.title()

if result[-1] not in (".", "!", "?"):
results[i] = results[i] + "."

return results[0] if isinstance(prompt, str) else results

def chat(prompt: str) -> str:
"""Get new message from chat-optimized language model
@@ -153,14 +171,14 @@ def chat(prompt: str) -> str:
prompt = prompt[7:].strip()

response = generate(
prompt,
[prompt],
max_tokens=config["max_tokens"],
repetition_penalty=1.3,
temperature=0.3,
topk=40,
prefix="Assistant:",
suppress=suppress,
)
)[0]

# Remove duplicate assistant being generated
if response.startswith("Assistant:"):
@@ -186,7 +204,7 @@ def code(prompt: str) -> str:
>>> code("def return_4():")
'...return 4...'
"""
result = generate(prompt, max_tokens=config["max_tokens"], topk=1, model="code")
result = generate([prompt], max_tokens=config["max_tokens"], topk=1, model="code")[0]

return result

@@ -212,7 +230,7 @@ def extract_answer(question: str, context: str) -> str:
'...Guido van Rossum...'
"""

return generate(f"{context}\n\n{question}")
return generate([f"{context}\n\n{question}"])[0]


def classify(doc: str, label1: str, label2: str) -> str:
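The `do` changes above follow a wrap-and-unwrap pattern: declare `str` and `list` signatures with `typing.overload`, normalize the input to a list, run one batched call, and return a bare string only when a string was passed in. Below is a standalone sketch of that pattern; `fake_batch_model` is a placeholder, not part of the library.

```python
from typing import List, Union, overload


def fake_batch_model(prompts: List[str]) -> List[str]:
    """Placeholder for a real batched model call such as `generate`."""
    return [p.upper() for p in prompts]


@overload
def do(prompt: str) -> str: ...
@overload
def do(prompt: List[str]) -> List[str]: ...


def do(prompt: Union[str, List[str]]) -> Union[str, List[str]]:
    # Wrap a single prompt so both cases share one batched code path
    prompts = [prompt] if isinstance(prompt, str) else prompt
    results = fake_batch_model(prompts)
    # Unwrap to a bare string only when a bare string was passed in
    return results[0] if isinstance(prompt, str) else results


print(do("hello"))             # 'HELLO'
print(do(["hello", "world"]))  # ['HELLO', 'WORLD']
```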
24 changes: 12 additions & 12 deletions languagemodels/inference.py
@@ -111,7 +111,7 @@ def chat_oa(engine, prompt, max_tokens=200, temperature=0):


def generate(
instruction,
instructions,
max_tokens=200,
temperature=0.1,
topk=1,
@@ -120,7 +120,7 @@ def generate(
suppress=[],
model="instruct",
):
"""Generates one completion for a prompt using an instruction-tuned model
"""Generates completions for a prompt
This may use a local model, or it may make an API call to an external
model if API keys are available.
@@ -141,25 +141,26 @@

fmt = model_info.get("prompt_fmt", "{instruction}")

prompt = fmt.replace("{instruction}", instruction)
prompts = [fmt.replace("{instruction}", inst) for inst in instructions]

outputs_ids = []
if hasattr(model, "translate_batch"):
results = model.translate_batch(
[tokenizer.encode(prompt).tokens],
target_prefix=[tokenizer.encode(prefix, add_special_tokens=False).tokens],
[tokenizer.encode(p).tokens for p in prompts],
target_prefix=[tokenizer.encode(prefix, add_special_tokens=False).tokens] * len(prompts),
repetition_penalty=repetition_penalty,
max_decoding_length=max_tokens,
sampling_temperature=temperature,
sampling_topk=topk,
suppress_sequences=suppress,
beam_size=1,
)
output_tokens = results[0].hypotheses[0]
output_ids = [tokenizer.token_to_id(t) for t in output_tokens]
text = tokenizer.decode(output_ids, skip_special_tokens=True)
outputs_tokens = [r.hypotheses[0] for r in results]
for output in outputs_tokens:
outputs_ids.append([tokenizer.token_to_id(t) for t in output])
else:
results = model.generate_batch(
[tokenizer.encode(prompt).tokens],
[tokenizer.encode(p).tokens for p in prompts],
repetition_penalty=repetition_penalty,
max_length=max_tokens,
sampling_temperature=temperature,
@@ -168,10 +169,9 @@
beam_size=1,
include_prompt_in_result=False,
)
output_ids = results[0].sequences_ids[0]
text = tokenizer.decode(output_ids, skip_special_tokens=True).lstrip()
outputs_ids = [r.sequences_ids[0] for r in results]

return text
return [tokenizer.decode(i, skip_special_tokens=True).lstrip() for i in outputs_ids]


def rank_instruct(input, targets):
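For reference, here is a minimal standalone sketch of the batched seq2seq path that `generate` now uses, with CTranslate2 and the `tokenizers` library. The model and tokenizer paths are placeholders; any CTranslate2-converted instruction model with a `tokenizer.json` should behave similarly.

```python
import ctranslate2
from tokenizers import Tokenizer

# Placeholder paths for a CTranslate2-converted seq2seq model
translator = ctranslate2.Translator("model_dir")
tokenizer = Tokenizer.from_file("model_dir/tokenizer.json")

prompts = [
    "Translate Spanish to English: Hola mundo!",
    "Is the following positive or negative: I love Star Trek.",
]

# Encode every prompt up front so the model decodes them as one batch
results = translator.translate_batch(
    [tokenizer.encode(p).tokens for p in prompts],
    max_decoding_length=200,
    sampling_temperature=0.1,
    sampling_topk=1,
    beam_size=1,
)

# One hypothesis per prompt; map tokens back to ids, then to text
for result in results:
    ids = [tokenizer.token_to_id(t) for t in result.hypotheses[0]]
    print(tokenizer.decode(ids, skip_special_tokens=True))
```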
