From 9d06ee2af640fcf6d65576f52ea9286f1057f278 Mon Sep 17 00:00:00 2001
From: Jon Craton
Date: Thu, 28 Dec 2023 09:13:24 -0500
Subject: [PATCH] Support batched inference

---
 changelog.md                |  4 ++++
 languagemodels/__init__.py  | 44 +++++++++++++++++++++++++++++++-------------
 languagemodels/inference.py | 24 ++++++++++------------
 3 files changed, 47 insertions(+), 25 deletions(-)

diff --git a/changelog.md b/changelog.md
index 45e3714..2c323a8 100644
--- a/changelog.md
+++ b/changelog.md
@@ -7,6 +7,10 @@
 - Improved search speed when searching many documents
 - Reduce memory usage for large document embeddings
 
+### Added
+
+- Support static batching by passing lists to `do`
+
 ## 0.12.0 - 2023-12-02
 
 ### Changed
diff --git a/languagemodels/__init__.py b/languagemodels/__init__.py
index a0bfa78..b5d146e 100644
--- a/languagemodels/__init__.py
+++ b/languagemodels/__init__.py
@@ -2,6 +2,7 @@
 import datetime
 import json
 import re
+from typing import overload
 
 from languagemodels.config import config
 from languagemodels.inference import (
@@ -34,9 +35,9 @@ def complete(prompt: str) -> str:
     """
 
     result = generate(
-        "Write a sentence", prefix=prompt,
+        ["Write a sentence"], prefix=prompt,
         max_tokens=config["max_tokens"], temperature=0.7, topk=40
-    )
+    )[0]
 
     if result.startswith(prompt):
         prefix_length = len(prompt)
@@ -44,13 +45,24 @@ def complete(prompt: str) -> str:
     else:
         return result
 
+@overload
+def do(prompt: list) -> list:
+    ...
+@overload
 def do(prompt: str) -> str:
+    ...
+
+def do(prompt):
     """Follow a single-turn instructional prompt
 
-    :param prompt: Instructional prompt to follow
+    :param prompt: Instructional prompt(s) to follow
     :return: Completion returned from the language model
 
+    Note that this function is overloaded to return a list of results
+    if a list of prompts is provided and a single string if a single
+    prompt is provided as a string.
+
     Examples:
 
     >>> do("Translate Spanish to English: Hola mundo!") #doctest: +SKIP
     'Hello world!'
@@ -61,17 +73,23 @@ def do(prompt: str) -> str:
 
     >>> do("Is the following positive or negative: I love Star Trek.")
     'Positive.'
+
+    >>> do(["Pick the sport from the list: baseball, texas, chemistry"] * 2)
+    ['Baseball.', 'Baseball.']
     """
 
-    result = generate(prompt, max_tokens=config["max_tokens"], topk=1)
-    if len(result.split()) == 1:
-        result = result.title()
+    prompts = [prompt] if isinstance(prompt, str) else prompt
 
-    if result[-1] not in (".", "!", "?"):
-        result = result + "."
+    results = generate(prompts, max_tokens=config["max_tokens"], topk=1)
 
-    return result
+    for i, result in enumerate(results):
+        if len(result.split()) == 1:
+            results[i] = result.title()
 
+        if result[-1] not in (".", "!", "?"):
+            results[i] = results[i] + "."
+
+    return results[0] if isinstance(prompt, str) else results
 
 def chat(prompt: str) -> str:
     """Get new message from chat-optimized language model
@@ -153,14 +171,14 @@ def chat(prompt: str) -> str:
         prompt = prompt[7:].strip()
 
     response = generate(
-        prompt,
+        [prompt],
         max_tokens=config["max_tokens"],
         repetition_penalty=1.3,
        temperature=0.3,
         topk=40,
         prefix="Assistant:",
         suppress=suppress,
-    )
+    )[0]
 
     # Remove duplicate assistant being generated
     if response.startswith("Assistant:"):
@@ -186,7 +204,7 @@ def chat(prompt: str) -> str:
     >>> code("def return_4():")
     '...return 4...'
""" - result = generate(prompt, max_tokens=config["max_tokens"], topk=1, model="code") + result = generate([prompt], max_tokens=config["max_tokens"], topk=1, model="code")[0] return result @@ -212,7 +230,7 @@ def extract_answer(question: str, context: str) -> str: '...Guido van Rossum...' """ - return generate(f"{context}\n\n{question}") + return generate([f"{context}\n\n{question}"])[0] def classify(doc: str, label1: str, label2: str) -> str: diff --git a/languagemodels/inference.py b/languagemodels/inference.py index 3d44777..b446a11 100644 --- a/languagemodels/inference.py +++ b/languagemodels/inference.py @@ -111,7 +111,7 @@ def chat_oa(engine, prompt, max_tokens=200, temperature=0): def generate( - instruction, + instructions, max_tokens=200, temperature=0.1, topk=1, @@ -120,7 +120,7 @@ def generate( suppress=[], model="instruct", ): - """Generates one completion for a prompt using an instruction-tuned model + """Generates completions for a prompt This may use a local model, or it may make an API call to an external model if API keys are available. @@ -141,12 +141,13 @@ def generate( fmt = model_info.get("prompt_fmt", "{instruction}") - prompt = fmt.replace("{instruction}", instruction) + prompts = [fmt.replace("{instruction}", inst) for inst in instructions] + outputs_ids = [] if hasattr(model, "translate_batch"): results = model.translate_batch( - [tokenizer.encode(prompt).tokens], - target_prefix=[tokenizer.encode(prefix, add_special_tokens=False).tokens], + [tokenizer.encode(p).tokens for p in prompts], + target_prefix=[tokenizer.encode(prefix, add_special_tokens=False).tokens] * len(prompts), repetition_penalty=repetition_penalty, max_decoding_length=max_tokens, sampling_temperature=temperature, @@ -154,12 +155,12 @@ def generate( suppress_sequences=suppress, beam_size=1, ) - output_tokens = results[0].hypotheses[0] - output_ids = [tokenizer.token_to_id(t) for t in output_tokens] - text = tokenizer.decode(output_ids, skip_special_tokens=True) + outputs_tokens = [r.hypotheses[0] for r in results] + for output in outputs_tokens: + outputs_ids.append([tokenizer.token_to_id(t) for t in output]) else: results = model.generate_batch( - [tokenizer.encode(prompt).tokens], + [tokenizer.encode(p).tokens for p in prompts], repetition_penalty=repetition_penalty, max_length=max_tokens, sampling_temperature=temperature, @@ -168,10 +169,9 @@ def generate( beam_size=1, include_prompt_in_result=False, ) - output_ids = results[0].sequences_ids[0] - text = tokenizer.decode(output_ids, skip_special_tokens=True).lstrip() + outputs_ids = results[0].sequences_ids[0] - return text + return [tokenizer.decode(i, skip_special_tokens=True).lstrip() for i in outputs_ids] def rank_instruct(input, targets):