fix: use all prompts in a batch to generate data #31
Conversation
The original code in openai_completion() only uses the first prompt of each batch and misses the last 4 of the 5 prompts in a batch. The fix adds an extra loop so that every prompt in the batch is used to call the OpenAI chat completion API and generate data. Signed-off-by: degaochu <[email protected]>
-    messages = [
-        {"role": "system", "content": get_sysprompt()},
-        {"role": "user", "content": prompt_batch[batch_id]},
-    ]
-
-    # Inference the model
-    try:
-        response = client.chat.completions.create(
-            messages=messages,
-            **shared_kwargs,
-        )
-    except OpenAIError as exc:
-        raise GenerateException(
-            f"There was a problem connecting to the server {exc}"
-        ) from exc
+    for prompt in prompt_batch:
+        messages = [
+            {"role": "system", "content": get_sysprompt()},
+            {"role": "user", "content": prompt},
+        ]
+        # Inference the model
+        try:
+            response = client.chat.completions.create(
+                messages=messages,
+                **shared_kwargs,
+            )
+        except OpenAIError as exc:
+            raise GenerateException(
+                f"There was a problem connecting to the server {exc}"
+            ) from exc
The old code seems to include the whole batch in a request, while the new code seems to do a request per prompt within that batch. The old code seems to match the description of the batch_size parameter:
batch_size: Number of prompts to send in a single request
Was it not actually working this way? Or what's the motivation for the change?
Or maybe the original code meant to do something equivalent to prompt_batches[batch_id] instead of prompt_batch[batch_id]? Indexing prompt_batch with batch_id doesn't make sense.
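For context, a minimal sketch of how the prompts are presumably chunked before this code runs; the names prompts, batch_size, prompt_batches, and batch_id are assumptions based on the identifiers visible in the diff, not the actual sdg code:

    # Hypothetical batching setup (names assumed from the diff, not real sdg code)
    prompts = ["p1", "p2", "p3", "p4", "p5", "p6", "p7"]
    batch_size = 5

    # Split the flat prompt list into batches of up to batch_size prompts each.
    prompt_batches = [
        prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)
    ]

    for batch_id, prompt_batch in enumerate(prompt_batches):
        # prompt_batch is a list of up to batch_size prompts, so
        # prompt_batch[batch_id] picks a single prompt whose position happens
        # to equal the batch index; only that one prompt ever reached the API.
        ...

Under that layout, prompt_batches[batch_id] would select the whole current batch, while prompt_batch[batch_id] selects a single prompt from it, which is the confusion being discussed here.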
The OpenAI client.chat.completions.create() API and the llama-cpp-python API server do not support batching, so the code is supposed to call the API once per prompt in the batch. The current implementation does not do that; I think it's a defect. It wastes 4 out of the 5 prompts in a batch, so the impact is that sampling has to run 5 times as often.
Some background:
The code appears to be borrowed from https://github.com/tatsu-lab/stanford_alpaca/blob/main/utils.py. stanford_alpaca calls the OpenAI API through openai.Completion.create(), where both the server and the client support batching.
I believe sdg will support batching (e.g. via vllm) in the future to improve efficiency, so it's better to keep all the batch-related code. Until batching is supported, an extra loop is needed.
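To illustrate the difference, here is a rough sketch using the standard openai>=1.0 client; the prompt list and model names are placeholders, not the actual sdg or stanford_alpaca code. The legacy completions endpoint accepts a list of prompts in one request, whereas the chat completions endpoint takes one conversation per call, so a loop is needed:

    from openai import OpenAI

    client = OpenAI()
    prompt_batch = ["prompt 1", "prompt 2", "prompt 3"]

    # Legacy completions endpoint: the whole batch can go in a single request.
    batched = client.completions.create(
        model="gpt-3.5-turbo-instruct",  # placeholder model name
        prompt=prompt_batch,
    )

    # Chat completions endpoint: one conversation per request, so loop over the batch.
    responses = []
    for prompt in prompt_batch:
        resp = client.chat.completions.create(
            model="gpt-3.5-turbo",  # placeholder model name
            messages=[{"role": "user", "content": prompt}],
        )
        responses.append(resp)

Keeping the prompt_batch structure in place would mean that, once a backend with real batch support is available, the inner loop could be replaced by a single batched request without reworking the surrounding code.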
I'm expecting another large PR today / this week that will probably remove this code and we'll no longer be based on that stanford_alpaca code. I'd like to keep this open but hold off a little bit to see if the change ends up still being relevant.
Ok.
I appreciate your efforts and the PR! The code changed here was removed from the repo, though, so I'm going to close this out.