Skip to content

Commit

Permalink
Merge pull request #85 from russellb/fix-multiple-output-datasets
Browse files (browse the repository at this point in the history)
generate_data: fix support for multiple leaf nodes
  • Loading branch information
russellb authored Jul 8, 2024
2 parents 23f2cb7 + 3834f60 commit d6091ff
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions src/instructlab/sdg/generate_data.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -78,20 +78,21 @@ def _get_response(logger, synth_example):
     return parts[1].strip() if len(parts) == 2 else parts[0].strip()


-def _gen_train_data(logger, machine_instruction_data, output_file_train):
+def _gen_train_data(logger, output_datasets, output_file_train):
     train_data = []
-    for synth_example in machine_instruction_data:
-        logger.debug(synth_example)
-        user = _get_question(logger, synth_example)
-        if len(synth_example.get("context", "")) > 0:
-            user += "\n" + synth_example["context"]
-        train_data.append(
-            {
-                "system": _SYS_PROMPT,
-                "user": _unescape(user),
-                "assistant": _unescape(_get_response(logger, synth_example)),
-            }
-        )
+    for output_dataset in output_datasets:
+        for synth_example in output_dataset:
+            logger.debug(synth_example)
+            user = _get_question(logger, synth_example)
+            if len(synth_example.get("context", "")) > 0:
+                user += "\n" + synth_example["context"]
+            train_data.append(
+                {
+                    "system": _SYS_PROMPT,
+                    "user": _unescape(user),
+                    "assistant": _unescape(_get_response(logger, synth_example)),
+                }
+            )

     with open(output_file_train, "w", encoding="utf-8") as outfile:
         for entry in train_data:
Expand Down Expand Up @@ -281,9 +282,9 @@ def generate_data(
         logger.debug("Dataset: %s" % ds)
         new_generated_data = sdg.generate(ds)
         generated_data = (
-            new_generated_data
+            [new_generated_data]
             if generated_data is None
-            else generated_data + new_generated_data
+            else generated_data + [new_generated_data]
         )
         logger.info("Generated %d samples" % len(generated_data))
         logger.debug("Generated data: %s" % generated_data)
Expand Down

0 comments on commit d6091ff

Please sign in to comment.