fix: use all prompts in a batch to generate data #31

Closed · wants to merge 1 commit
src/instructlab/sdg/utils.py (16 additions, 16 deletions)

@@ -147,23 +147,23 @@ def openai_completion(
         f"Model {model_name} is not served by the server. These are the served models {model_ids}"
     )

-    messages = [
-        {"role": "system", "content": get_sysprompt()},
-        {"role": "user", "content": prompt_batch[batch_id]},
-    ]
-
-    # Inference the model
-    try:
-        response = client.chat.completions.create(
-            messages=messages,
-            **shared_kwargs,
-        )
-    except OpenAIError as exc:
-        raise GenerateException(
-            f"There was a problem connecting to the server {exc}"
-        ) from exc
+    for prompt in prompt_batch:
+        messages = [
+            {"role": "system", "content": get_sysprompt()},
+            {"role": "user", "content": prompt},
+        ]
+        # Inference the model
+        try:
+            response = client.chat.completions.create(
+                messages=messages,
+                **shared_kwargs,
+            )
+        except OpenAIError as exc:
+            raise GenerateException(
+                f"There was a problem connecting to the server {exc}"
+            ) from exc
Comment on lines -150 to +164

Member:
The old code seems to include the whole batch in a request, while the new code seems to do a request per prompt within that batch. The old code seems to match the description of the batch_size parameter:

batch_size: Number of prompts to send in a single request

Was it not actually working this way?

Or what's the motivation for the change?
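
For reference, the grouping that the quoted batch_size description implies usually looks something like the helper below; chunk_prompts and the sample prompts are hypothetical illustrations, not code from this repo.

# Hypothetical illustration of the quoted batch_size semantics: prompts are
# grouped into lists of at most batch_size items, and each group is meant to
# go out as a single request.
def chunk_prompts(prompts, batch_size):
    return [prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)]

prompt_batches = chunk_prompts(["p0", "p1", "p2", "p3", "p4", "p5", "p6"], batch_size=5)
# -> [["p0", "p1", "p2", "p3", "p4"], ["p5", "p6"]]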

Member:

Or maybe the original code meant to do something equivalent to prompt_batches[batch_id] instead of prompt_batch[batch_id]? Indexing prompt_batch with batch_id doesn't make sense.
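
A toy example of the two readings; the lists below are made up for illustration and are not taken from the repo.

# prompt_batches is a list of batches; prompt_batch is one batch drawn from it.
prompt_batches = [["p0", "p1", "p2"], ["p3", "p4", "p5"]]

batch_id = 1
prompt_batch = prompt_batches[batch_id]   # the whole batch: ["p3", "p4", "p5"]

# What the old code effectively did: index the batch itself with the batch id,
# which selects a single prompt ("p4" here) and drops the rest of the batch.
single_prompt = prompt_batch[batch_id]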

Author (@chudegao, Jun 25, 2024):

The OpenAI client.chat.completions.create() API and the llama-cpp-python API server do not support batching, so the code is supposed to call the API once per prompt in the batch. The current implementation does not do this, which I think is a defect: it wastes 4 out of 5 prompts in a batch, so roughly 5 times more sampling is required.

Some background: the code appears to be borrowed from https://github.com/tatsu-lab/stanford_alpaca/blob/main/utils.py. stanford_alpaca calls the OpenAI API with openai.Completion.create(), where both the server and the client support batching.

I believe sdg will support batching (e.g. vLLM) in the future to improve efficiency, so it's better to keep all the batch-related code. Until that is supported, an extra loop is needed.
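
A minimal sketch of the difference described above, assuming the openai>=1.0 Python client; the model name, base URL, and API key are placeholders, not values from this project.

from openai import OpenAI

# Placeholder connection details for a local OpenAI-compatible server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")

prompt_batch = ["first prompt", "second prompt"]

# Legacy completions endpoint (the stanford_alpaca pattern): the whole batch
# can go out in one request because `prompt` accepts a list of strings.
batched = client.completions.create(model="placeholder-model", prompt=prompt_batch)

# Chat completions endpoint (what this code uses): one request per prompt,
# with the results collected client-side, hence the extra loop in this PR.
completions = []
for prompt in prompt_batch:
    response = client.chat.completions.create(
        model="placeholder-model",
        messages=[{"role": "user", "content": prompt}],
    )
    completions.extend(response.choices)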

Member:

I'm expecting another large PR today or this week that will probably remove this code, after which we'll no longer be based on that stanford_alpaca code. I'd like to keep this open but hold off a bit to see whether the change ends up still being relevant.

Author (@chudegao):

Ok.


-    completions.extend(response.choices)
+        completions.extend(response.choices)

if return_text:
completions = [completion.text for completion in completions]