From 3834f6077f475ad81c4bee699b8433bbc0b934bc Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 3 Jul 2024 22:54:42 -0400 Subject: [PATCH] generate_data: fix support for multiple leaf nodes When running generate_data() with a taxonomy with more than one leaf node, the code previously appended the resulting datasets as if they behaved like a Python list. They are a `datasets.Dataset` and do not support the `+` operator. Instead, keep a list of these datasets. Also update the code that writes these results to a file to handle the extra level of iteration now required. Signed-off-by: Russell Bryant --- src/instructlab/sdg/generate_data.py | 31 ++++++++++++++-------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index 70a2e49d..72ea9109 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -78,20 +78,21 @@ def _get_response(logger, synth_example): return parts[1].strip() if len(parts) == 2 else parts[0].strip() -def _gen_train_data(logger, machine_instruction_data, output_file_train): +def _gen_train_data(logger, output_datasets, output_file_train): train_data = [] - for synth_example in machine_instruction_data: - logger.debug(synth_example) - user = _get_question(logger, synth_example) - if len(synth_example.get("context", "")) > 0: - user += "\n" + synth_example["context"] - train_data.append( - { - "system": _SYS_PROMPT, - "user": _unescape(user), - "assistant": _unescape(_get_response(logger, synth_example)), - } - ) + for output_dataset in output_datasets: + for synth_example in output_dataset: + logger.debug(synth_example) + user = _get_question(logger, synth_example) + if len(synth_example.get("context", "")) > 0: + user += "\n" + synth_example["context"] + train_data.append( + { + "system": _SYS_PROMPT, + "user": _unescape(user), + "assistant": _unescape(_get_response(logger, synth_example)), + } + ) with open(output_file_train, "w", encoding="utf-8") as outfile: for entry in train_data: @@ -281,9 +282,9 @@ def generate_data( logger.debug("Dataset: %s" % ds) new_generated_data = sdg.generate(ds) generated_data = ( - new_generated_data + [new_generated_data] if generated_data is None - else generated_data + new_generated_data + else generated_data + [new_generated_data] ) logger.info("Generated %d samples" % len(generated_data)) logger.debug("Generated data: %s" % generated_data)