diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index 70a2e49d..72ea9109 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -78,20 +78,21 @@ def _get_response(logger, synth_example): return parts[1].strip() if len(parts) == 2 else parts[0].strip() -def _gen_train_data(logger, machine_instruction_data, output_file_train): +def _gen_train_data(logger, output_datasets, output_file_train): train_data = [] - for synth_example in machine_instruction_data: - logger.debug(synth_example) - user = _get_question(logger, synth_example) - if len(synth_example.get("context", "")) > 0: - user += "\n" + synth_example["context"] - train_data.append( - { - "system": _SYS_PROMPT, - "user": _unescape(user), - "assistant": _unescape(_get_response(logger, synth_example)), - } - ) + for output_dataset in output_datasets: + for synth_example in output_dataset: + logger.debug(synth_example) + user = _get_question(logger, synth_example) + if len(synth_example.get("context", "")) > 0: + user += "\n" + synth_example["context"] + train_data.append( + { + "system": _SYS_PROMPT, + "user": _unescape(user), + "assistant": _unescape(_get_response(logger, synth_example)), + } + ) with open(output_file_train, "w", encoding="utf-8") as outfile: for entry in train_data: @@ -281,9 +282,9 @@ def generate_data( logger.debug("Dataset: %s" % ds) new_generated_data = sdg.generate(ds) generated_data = ( - new_generated_data + [new_generated_data] if generated_data is None - else generated_data + new_generated_data + else generated_data + [new_generated_data] ) logger.info("Generated %d samples" % len(generated_data)) logger.debug("Generated data: %s" % generated_data)