From d703fb582b4c691d9b9c90f12bf8ea5f50373937 Mon Sep 17 00:00:00 2001 From: degaochu Date: Fri, 21 Jun 2024 10:16:48 +0800 Subject: [PATCH] fix: use all prompts in a batch to generate data The original code in openai_completion() only used the first prompt and missed the last 4 prompts of a batch of 5. The fix adds an extra loop that uses all prompts in the batch when calling the OpenAI chat completion API to generate data. Signed-off-by: degaochu --- src/instructlab/sdg/utils.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/instructlab/sdg/utils.py b/src/instructlab/sdg/utils.py index 83423fe8..e92d772a 100644 --- a/src/instructlab/sdg/utils.py +++ b/src/instructlab/sdg/utils.py @@ -147,23 +147,23 @@ def openai_completion( f"Model {model_name} is not served by the server. These are the served models {model_ids}" ) - messages = [ - {"role": "system", "content": get_sysprompt()}, - {"role": "user", "content": prompt_batch[batch_id]}, - ] - - # Inference the model - try: - response = client.chat.completions.create( - messages=messages, - **shared_kwargs, - ) - except OpenAIError as exc: - raise GenerateException( - f"There was a problem connecting to the server {exc}" - ) from exc + for prompt in prompt_batch: + messages = [ + {"role": "system", "content": get_sysprompt()}, + {"role": "user", "content": prompt}, + ] + # Inference the model + try: + response = client.chat.completions.create( + messages=messages, + **shared_kwargs, + ) + except OpenAIError as exc: + raise GenerateException( + f"There was a problem connecting to the server {exc}" + ) from exc - completions.extend(response.choices) + completions.extend(response.choices) if return_text: completions = [completion.text for completion in completions]