save an additional train dataset in the converted format
Signed-off-by: Oindrilla Chatterjee <[email protected]>
oindrillac committed Jul 5, 2024
1 parent 4233300 commit 044849d
Showing 1 changed file with 41 additions and 16 deletions.
57 changes: 41 additions & 16 deletions src/instructlab/sdg/generate_data.py
@@ -76,6 +76,7 @@ def _get_response(logger, synth_example):
         logger.warning(f"Failed to split generated q&a: {synth_example['output']}")
     return parts[1].strip() if len(parts) == 2 else parts[0].strip()
 
+
 def _convert_to_messages(sample):
     """
     Convert a sample dictionary to contain 'messages' and 'metadata' columns required for training.
@@ -90,38 +91,53 @@ def _convert_to_messages(sample):
         {"content": user_query, "role": "user"},
         {"content": sample["targets"], "role": "assistant"},
     ]
-    metadata = {key: value for key, value in sample.items() if key not in ["messages", "inputs", "targets"]}
+    metadata = {
+        key: value
+        for key, value in sample.items()
+        if key not in ["messages", "inputs", "targets"]
+    }
     sample["metadata"] = json.dumps(metadata)
 
     # keeping required keys for messages training format
-    sample = {
-        "messages": sample["messages"],
-        "metadata": sample["metadata"]
-    }
+    sample = {"messages": sample["messages"], "metadata": sample["metadata"]}
 
     return sample
 
 
-def _gen_train_data(logger, machine_instruction_data, output_file_train):
+def _gen_train_data(
+    logger, machine_instruction_data, output_file_train, output_file_messages
+):
     train_data = []
+    messages_data = []
     for synth_example in machine_instruction_data:
         logger.debug(synth_example)
         user = _get_question(logger, synth_example)
         if len(synth_example.get("context", "")) > 0:
             user += "\n" + synth_example["context"]
-        train_data.append(
-            {
-                "system": get_sysprompt(),
-                "user": _unescape(user),
-                "assistant": _unescape(_get_response(logger, synth_example)),
-            }
-        )
+        assistant = _unescape(_get_response(logger, synth_example))
+        train_entry = {
+            "system": get_sysprompt(),
+            "user": _unescape(user),
+            "assistant": assistant,
+        }
+        train_data.append(train_entry)
+        sample = {
+            "inputs": _unescape(user),
+            "targets": assistant,
+            "system": get_sysprompt(),
+        }
+        messages_data.append(_convert_to_messages(sample))
 
     with open(output_file_train, "w", encoding="utf-8") as outfile:
         for entry in train_data:
             json.dump(entry, outfile, ensure_ascii=False)
             outfile.write("\n")
 
+    with open(output_file_messages, "w", encoding="utf-8") as outfile:
+        for entry in messages_data:
+            json.dump(entry, outfile, ensure_ascii=False)
+            outfile.write("\n")
+
 
 def _gen_test_data(
     leaf_nodes,
@@ -245,12 +261,12 @@ def generate_data(
     output_file_generated = f"generated_{name}_{date_suffix}.json"
     output_file_test = f"test_{name}_{date_suffix}.jsonl"
     output_file_train = f"train_{name}_{date_suffix}.jsonl"
+    output_file_messages = f"messages_{name}_{date_suffix}.jsonl"
 
     _gen_test_data(
         leaf_nodes,
         os.path.join(output_dir, output_file_test),
     )
-
     logger.debug(f"Generating to: {os.path.join(output_dir, output_file_generated)}")
 
     orig_cert = (tls_client_cert, tls_client_key, tls_client_passwd)
@@ -269,6 +285,7 @@
 
     # TODO -- llama-cpp doesn't support batching, we need to get a hint from the CLI
     # about whether we can turn this on (whether vllm is used or not)
+
     batched = False
 
     sdg_knowledge, sdg_freeform_skill, sdg_grounded_skill = _sdg_init(
Expand Down Expand Up @@ -315,15 +332,23 @@ def generate_data(
if generated_data is None:
generated_data = []

_gen_train_data(logger, generated_data, os.path.join(output_dir, output_file_train))
_gen_train_data(
logger,
generated_data,
os.path.join(output_dir, output_file_train),
os.path.join(output_dir, output_file_messages),
)

# TODO
# This is for backwards compatibility. The file existing previously, so we'll keep it for now.
# I believe the github bot assumes it is present for presenting generated data to a taxonomy
# reviewer or contributor. Otherwise, I don't see a consumer of it in this repo or the
# `ilab` CLI.
_gen_train_data(
logger, generated_data, os.path.join(output_dir, output_file_generated)
logger,
generated_data,
os.path.join(output_dir, output_file_generated),
os.path.join(output_dir, output_file_messages),
)

generate_duration = time.time() - generate_start
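
For context on what the new output contains, below is a minimal, self-contained sketch of the conversion performed by _convert_to_messages. The sample values, the helper name, and the assumption that user_query is taken from sample["inputs"] are illustrative only and not part of the repository code.

import json


def convert_to_messages_sketch(sample):
    """Illustrative re-implementation of the _convert_to_messages logic in this diff."""
    # Assumption: user_query is derived from sample["inputs"] earlier in the real function.
    user_query = sample["inputs"]
    sample["messages"] = [
        {"content": user_query, "role": "user"},
        {"content": sample["targets"], "role": "assistant"},
    ]
    metadata = {
        key: value
        for key, value in sample.items()
        if key not in ["messages", "inputs", "targets"]
    }
    sample["metadata"] = json.dumps(metadata)
    return {"messages": sample["messages"], "metadata": sample["metadata"]}


# Hypothetical synthetic example with the same keys _gen_train_data assembles.
record = convert_to_messages_sketch(
    {
        "inputs": "What is the capital of France?",
        "targets": "The capital of France is Paris.",
        "system": "You are a helpful assistant.",  # stand-in for get_sysprompt()
    }
)
# One such JSON object per line is appended to the new messages_*.jsonl file.
print(json.dumps(record, ensure_ascii=False))

Note that with this filter the system prompt ends up inside the JSON-encoded metadata column rather than as a system message, since only the messages, inputs, and targets keys are excluded.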

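As a quick downstream sanity check, here is a hedged example of loading the new messages file back after generation; the path is hypothetical and only standard-library calls are used.

import json

# Hypothetical path; real files follow the messages_{name}_{date_suffix}.jsonl
# pattern shown in generate_data() above.
path = "output/messages_knowledge_2024-07-05.jsonl"

with open(path, encoding="utf-8") as infile:
    records = [json.loads(line) for line in infile]

for rec in records[:3]:
    # Each record carries a chat-style "messages" list and a JSON-encoded "metadata" string.
    metadata = json.loads(rec["metadata"])
    print(len(rec["messages"]), "turns; metadata keys:", sorted(metadata))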