From 38c424169a65f53f4e41b51b55d0d56d1b25ffb2 Mon Sep 17 00:00:00 2001 From: Oindrilla Chatterjee Date: Fri, 5 Jul 2024 14:15:36 -0400 Subject: [PATCH] align test data with messages format and save Signed-off-by: Oindrilla Chatterjee --- src/instructlab/sdg/generate_data.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index 815c3f7a..0915c751 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -142,8 +142,10 @@ def _gen_train_data( def _gen_test_data( leaf_nodes, output_file_test, + output_file_messages, ): test_data = [] + messages_data = [] for _, leaf_node in leaf_nodes.items(): for seed_example in leaf_node: user = seed_example["instruction"] # question @@ -158,12 +160,23 @@ def _gen_test_data( "assistant": _unescape(seed_example["output"]), # answer } ) + sample = { + "inputs": _unescape(user), + "targets": _unescape(seed_example["output"]), + "system": get_sysprompt(), + } + messages_data.append(_convert_to_messages(sample)) with open(output_file_test, "w", encoding="utf-8") as outfile: for entry in test_data: json.dump(entry, outfile, ensure_ascii=False) outfile.write("\n") + with open(output_file_messages, "w", encoding="utf-8") as outfile: + for entry in messages_data: + json.dump(entry, outfile, ensure_ascii=False) + outfile.write("\n") + def _sdg_init(pipeline, client, model_family, model_name, num_iters, batched): knowledge_flow_types = [] @@ -261,11 +274,13 @@ def generate_data( output_file_generated = f"generated_{name}_{date_suffix}.json" output_file_test = f"test_{name}_{date_suffix}.jsonl" output_file_train = f"train_{name}_{date_suffix}.jsonl" - output_file_messages = f"messages_{name}_{date_suffix}.jsonl" + output_file_messages_train = f"train_messages_{name}_{date_suffix}.jsonl" + output_file_messages_test = f"test_messages_{name}_{date_suffix}.jsonl" _gen_test_data( leaf_nodes, os.path.join(output_dir, output_file_test), + os.path.join(output_dir, output_file_messages_test), ) logger.debug(f"Generating to: {os.path.join(output_dir, output_file_generated)}") @@ -336,7 +351,7 @@ def generate_data( logger, generated_data, os.path.join(output_dir, output_file_train), - os.path.join(output_dir, output_file_messages), + os.path.join(output_dir, output_file_messages_train), ) # TODO @@ -344,12 +359,6 @@ def generate_data( # I believe the github bot assumes it is present for presenting generated data to a taxonomy # reviewer or contributor. Otherwise, I don't see a consumer of it in this repo or the # `ilab` CLI. - _gen_train_data( - logger, - generated_data, - os.path.join(output_dir, output_file_generated), - os.path.join(output_dir, output_file_messages), - ) generate_duration = time.time() - generate_start logger.info(f"Generation took {generate_duration:.2f}s")