Merge pull request #15 from makelinux/gen_test_data
Split gen_test_data() out of the very long generate_data()
russellb authored Jun 24, 2024
2 parents f3090c1 + c0f7320 commit cba3a62
Showing 1 changed file with 97 additions and 74 deletions.
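
This commit moves the train/test file writing out of generate_data() into two focused helpers, _gen_test_data() and _gen_train_data(), and lifts unescape() to module level. Both helpers end with the same write pattern: one JSON object per line (JSONL), with ensure_ascii=False so non-ASCII text stays readable. A minimal, self-contained sketch of that pattern (illustrative stand-ins, not the instructlab code itself):

import json

def write_jsonl(entries, path):
    # One JSON object per line, mirroring the with-open/json.dump loop
    # at the end of _gen_train_data() and _gen_test_data() in this diff.
    with open(path, "w", encoding="utf-8") as outfile:
        for entry in entries:
            json.dump(entry, outfile, ensure_ascii=False)
            outfile.write("\n")

# The "system" value here is a placeholder for what utils.get_sysprompt() returns.
write_jsonl(
    [{"system": "You are a helpful assistant.", "user": "Say hi", "assistant": "Hi!"}],
    "train_demo.jsonl",
)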
171 changes: 97 additions & 74 deletions src/instructlab/sdg/generate_data.py
@@ -362,53 +362,19 @@ def read_taxonomy(*args, **kwargs):
    return instructlab.utils.read_taxonomy(*args, **kwargs)


def generate_data(
    logger,
    api_base,
    tls_insecure,
    model_family: str,
    yaml_rules: Optional[str] = None,
    output_dir: Optional[str] = None,
    taxonomy: Optional[str] = None,
    taxonomy_base: Optional[str] = None,
    prompt_file_path: Optional[str] = None,
    model_name: Optional[str] = None,
    num_cpus: Optional[int] = None,
    num_instructions_to_generate: Optional[int] = None,
    num_prompt_instructions=2,
    request_batch_size=5,
    temperature=1.0,
    top_p=1.0,
    rouge_threshold: Optional[float] = None,
    console_output=True,
    api_key: Optional[str] = None,
    chunk_word_count=None,
    server_ctx_size=None,
    tls_client_cert: Optional[str] = None,
    tls_client_key: Optional[str] = None,
    tls_client_passwd: Optional[str] = None,
):
    seed_instruction_data = []
    machine_seed_instruction_data = []
    generate_start = time.time()


def unescape(s):
    return bytes(s, "utf-8").decode("utf-8")

    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    # check taxonomy first then seed_tasks_path
    # throw an error if both not found
    # pylint: disable=broad-exception-caught,raise-missing-from
    if taxonomy and os.path.exists(taxonomy):
        seed_instruction_data = read_taxonomy(
            logger, taxonomy, taxonomy_base, yaml_rules
        )
    else:
        raise SystemExit(f"Error: taxonomy ({taxonomy}) does not exist.")

    prompt_template = check_prompt_file(
        prompt_file_path, get_model_family(model_family, model_name)
    )
    max_seed_tokens = max_seed_example_tokens(server_ctx_size, len(prompt_template))
def _gen_test_data(
    logger,
    seed_instruction_data,
    max_seed_tokens,
    taxonomy,
    chunk_word_count,
    server_ctx_size,
    output_file_test,
):
    max_seed_chars = num_chars_from_tokens(max_seed_tokens)
    for seed_example in seed_instruction_data:
        if (
@@ -426,9 +392,6 @@ def generate_data(
    if not seeds:
        raise SystemExit("Nothing to generate. Exiting.")

    def unescape(s):
        return bytes(s, "utf-8").decode("utf-8")

    test_data = []
    for seed_example in seed_instruction_data:
        user = seed_example["instruction"]
@@ -457,6 +420,80 @@ def unescape(s):
                fg="red",
            )
            raise click.exceptions.Exit(1)
    # utils.jdump(test_data, os.path.join(output_dir, output_file_test))
    with open(output_file_test, "w", encoding="utf-8") as outfile:
        for entry in test_data:
            json.dump(entry, outfile, ensure_ascii=False)
            outfile.write("\n")


def _gen_train_data(machine_instruction_data, output_file_train):
    train_data = []
    for synth_example in machine_instruction_data:
        user = synth_example["instruction"]
        if len(synth_example["input"]) > 0:
            user += "\n" + synth_example["input"]
        train_data.append(
            {
                "system": utils.get_sysprompt(),
                "user": unescape(user),
                "assistant": unescape(synth_example["output"]),
            }
        )
    # utils.jdump(train_data, output_file_train)
    with open(output_file_train, "w", encoding="utf-8") as outfile:
        for entry in train_data:
            json.dump(entry, outfile, ensure_ascii=False)
            outfile.write("\n")


def generate_data(
    logger,
    api_base,
    tls_insecure,
    model_family: str,
    yaml_rules: Optional[str] = None,
    output_dir: Optional[str] = None,
    taxonomy: Optional[str] = None,
    taxonomy_base: Optional[str] = None,
    prompt_file_path: Optional[str] = None,
    model_name: Optional[str] = None,
    num_cpus: Optional[int] = None,
    num_instructions_to_generate: Optional[int] = None,
    num_prompt_instructions=2,
    request_batch_size=5,
    temperature=1.0,
    top_p=1.0,
    rouge_threshold: Optional[float] = None,
    console_output=True,
    api_key: Optional[str] = None,
    chunk_word_count=None,
    server_ctx_size=None,
    tls_client_cert: Optional[str] = None,
    tls_client_key: Optional[str] = None,
    tls_client_passwd: Optional[str] = None,
):
    seed_instruction_data = []
    machine_seed_instruction_data = []
    generate_start = time.time()

    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    # check taxonomy first then seed_tasks_path
    # throw an error if both not found
    # pylint: disable=broad-exception-caught,raise-missing-from
    if taxonomy and os.path.exists(taxonomy):
        seed_instruction_data = read_taxonomy(
            logger, taxonomy, taxonomy_base, yaml_rules
        )
    else:
        raise SystemExit(f"Error: taxonomy ({taxonomy}) does not exist.")

    prompt_template = check_prompt_file(
        prompt_file_path, get_model_family(model_family, model_name)
    )
    max_seed_tokens = max_seed_example_tokens(server_ctx_size, len(prompt_template))

    name = Path(model_name).stem  # Just in case it is a file path
    date_suffix = datetime.now().replace(microsecond=0).isoformat().replace(":", "_")
@@ -466,6 +503,15 @@ def unescape(s):
    output_file_discarded = os.path.join(
        output_dir, f"discarded_{name}_{date_suffix}.log"
    )
    _gen_test_data(
        logger,
        seed_instruction_data,
        max_seed_tokens,
        taxonomy,
        chunk_word_count,
        server_ctx_size,
        os.path.join(output_dir, output_file_test),
    )
    logger.debug(f"Generating to: {os.path.join(output_dir, output_file)}")

    request_idx = 0
@@ -580,32 +626,9 @@ def unescape(s):
        f"Generated {total} instructions(discarded {discarded}), rouged {total - keep}, kept {keep} instructions"
    )
    utils.jdump(machine_instruction_data, os.path.join(output_dir, output_file))
    train_data = []
    for synth_example in machine_instruction_data:
        user = synth_example["instruction"]
        if len(synth_example["input"]) > 0:
            user += "\n" + synth_example["input"]
        train_data.append(
            {
                "system": utils.get_sysprompt(),
                "user": unescape(user),
                "assistant": unescape(synth_example["output"]),
            }
        )
    # utils.jdump(train_data, os.path.join(output_dir, output_file_train))
    with open(
        os.path.join(output_dir, output_file_train), "w", encoding="utf-8"
    ) as outfile:
        for entry in train_data:
            json.dump(entry, outfile, ensure_ascii=False)
            outfile.write("\n")
    # utils.jdump(test_data, os.path.join(output_dir, output_file_test))
    with open(
        os.path.join(output_dir, output_file_test), "w", encoding="utf-8"
    ) as outfile:
        for entry in test_data:
            json.dump(entry, outfile, ensure_ascii=False)
            outfile.write("\n")
    _gen_train_data(
        machine_instruction_data, os.path.join(output_dir, output_file_train)
    )

    progress_bar.close()

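One payoff of the extraction is that the file-writing steps can now be exercised without running the whole generation loop. A hedged sketch (not part of this commit) of calling the extracted helper in isolation, assuming instructlab-sdg at this revision is installed and importable in your environment:

import json
import os
import tempfile

from instructlab.sdg.generate_data import _gen_train_data

# A single synthetic example shaped like the machine_instruction_data
# records seen in the diff (instruction/input/output keys).
synthetic = [{"instruction": "Name a color.", "input": "", "output": "Blue."}]

with tempfile.TemporaryDirectory() as tmp:
    out_path = os.path.join(tmp, "train_demo.jsonl")
    _gen_train_data(synthetic, out_path)
    with open(out_path, encoding="utf-8") as f:
        for line in f:
            record = json.loads(line)
            # Each JSONL record carries the system/user/assistant triple.
            print(record["user"], "->", record["assistant"])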
