From 9ee7f7073c2c70b83a32bd0eb575b2053c23da79 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 25 Jun 2024 16:31:19 -0400 Subject: [PATCH] Use new API in generate_data This makes use of the new SDG API under the generate_data() method used by the CLI. It uses new simple workflows for knowlege and skills that inteded for basic usable with a small model for testing and demo purposes. The full pipelines provided in the library will only work in larger environments capable of running Mixtral-8x7b. There are still various TODOs in the code, but this is enough to start with. I'm sure we will make enhancements to these basic workflows that still work for the small environments. Signed-off-by: Russell Bryant --- .pylintrc | 1 - pyproject.toml | 1 - .../configs/knowledge/simple_generate_qa.yaml | 5 +- .../skills/simple_generate_qa_freeform.yaml | 33 + .../skills/simple_generate_qa_grounded.yaml | 37 + src/instructlab/sdg/default_flows.py | 41 +- src/instructlab/sdg/generate_data.py | 768 +++++------------- src/instructlab/sdg/utils/chunking.py | 45 +- src/instructlab/sdg/utils/openai.py | 175 ---- src/instructlab/sdg/utils/taxonomy.py | 59 ++ tests/test_chunking.py | 2 +- 11 files changed, 385 insertions(+), 782 deletions(-) create mode 100644 src/instructlab/sdg/configs/skills/simple_generate_qa_freeform.yaml create mode 100644 src/instructlab/sdg/configs/skills/simple_generate_qa_grounded.yaml delete mode 100644 src/instructlab/sdg/utils/openai.py diff --git a/.pylintrc b/.pylintrc index 821ae6af..3b4da7a8 100644 --- a/.pylintrc +++ b/.pylintrc @@ -444,7 +444,6 @@ disable=raw-checker-failed, logging-too-many-args, attribute-defined-outside-init, abstract-method, - pointless-statement, wrong-import-order, line-too-long, logging-fstring-interpolation diff --git a/pyproject.toml b/pyproject.toml index d13fd6d0..8178ca06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,6 @@ known-local-folder = ["tuning"] disable_error_code = ["import-not-found", "import-untyped"] exclude = [ "^src/instructlab/sdg/generate_data\\.py$", - "^src/instructlab/sdg/utils/openai\\.py$", "^src/instructlab/sdg/utils/taxonomy\\.py$", "^src/instructlab/sdg/default_flows\\.py$", "^src/instructlab/sdg/llmblock\\.py$", diff --git a/src/instructlab/sdg/configs/knowledge/simple_generate_qa.yaml b/src/instructlab/sdg/configs/knowledge/simple_generate_qa.yaml index c20add97..9ad6fa77 100644 --- a/src/instructlab/sdg/configs/knowledge/simple_generate_qa.yaml +++ b/src/instructlab/sdg/configs/knowledge/simple_generate_qa.yaml @@ -28,10 +28,7 @@ examples: | {document} generation: | - Provide a single question and answer pair based on the document: - - Document: - {{document}} + Provide a single question and answer pair based on the document. start_tags: [""] end_tags: [""] diff --git a/src/instructlab/sdg/configs/skills/simple_generate_qa_freeform.yaml b/src/instructlab/sdg/configs/skills/simple_generate_qa_freeform.yaml new file mode 100644 index 00000000..2913d7df --- /dev/null +++ b/src/instructlab/sdg/configs/skills/simple_generate_qa_freeform.yaml @@ -0,0 +1,33 @@ +system: You are a very knowledgeable AI Assistant that will faithfully assist the user with their task. + +introduction: Develop a series of question and answer pairs to perform a task. + +principles: | +Here are the requirements: + 1. Try not to repeat the verb for each instruction to maximize diversity. + 2. The language used for the instruction also should be diverse. For example, you should combine questions with imperative instructions. + 3. 
The type of instructions should be similar to provided examples. The generated instruction and the output should be grounded in the provided document. + 4. A GPT language model should be able to complete the instruction. For example, do not ask the assistant to create any visual or audio output. For another example, do not ask the assistant to wake you up at 5pm or set a reminder because it cannot perform any action. + 5. The instructions should be in English. + 6. The instructions should be 1 to 2 sentences long. Either an imperative sentence or a question is permitted. + 7. The output should be an appropriate response to the input and the instruction. Long outputs are preferable. + +examples: | + The task is {task_description}. + + Here are some examples to help you understand the type of questions that are asked for: + + {question_1} + {response_1} + + {question_2} + {response_2} + + {question_3} + {response_3} + +generation: | + Provide a single question and answer pair based on the examples. + +start_tags: [""] +end_tags: [""] diff --git a/src/instructlab/sdg/configs/skills/simple_generate_qa_grounded.yaml b/src/instructlab/sdg/configs/skills/simple_generate_qa_grounded.yaml new file mode 100644 index 00000000..fe48c99c --- /dev/null +++ b/src/instructlab/sdg/configs/skills/simple_generate_qa_grounded.yaml @@ -0,0 +1,37 @@ +system: You are a very knowledgeable AI Assistant that will faithfully assist the user with their task. + +introduction: Develop a series of question and answer pairs to perform a task. + +principles: | +Here are the requirements: + 1. Try not to repeat the verb for each instruction to maximize diversity. + 2. The language used for the instruction also should be diverse. For example, you should combine questions with imperative instructions. + 3. The type of instructions should be similar to provided examples. The generated instruction and the output should be grounded in the provided document. + 4. A GPT language model should be able to complete the instruction. For example, do not ask the assistant to create any visual or audio output. For another example, do not ask the assistant to wake you up at 5pm or set a reminder because it cannot perform any action. + 5. The instructions should be in English. + 6. The instructions should be 1 to 2 sentences long. Either an imperative sentence or a question is permitted. + 7. The output should be an appropriate response to the input and the instruction. Long outputs are preferable. + +examples: | + The task is {task_description}. + + Here is some context for the example questions: + + {context} + + Here are some examples to help you understand the type of questions that are asked for: + + {question_1} + {response_1} + + {question_2} + {response_2} + + {question_3} + {response_3} + +generation: | + Provide a single question and answer pair based on the examples. 
+ +start_tags: [""] +end_tags: [""] diff --git a/src/instructlab/sdg/default_flows.py b/src/instructlab/sdg/default_flows.py index d12ce4ff..b2abed88 100644 --- a/src/instructlab/sdg/default_flows.py +++ b/src/instructlab/sdg/default_flows.py @@ -40,17 +40,15 @@ def get_flow(self) -> list: pass -class SimpleKnowledgeFlow(Flow): +class _SimpleFlow(Flow): def get_flow(self) -> list: sdg_base = resources.files(__package__) return [ { "block_type": LLMBlock, "block_config": { - "block_name": "gen_knowledge", - "config_path": os.path.join( - sdg_base, "configs/knowledge/simple_generate_qa.yaml" - ), + "block_name": "", # must be set by subclass + "config_path": "", # must be set by subclass "client": self.client, "model_id": self.model_id, "model_prompt": _get_model_prompt(self.model_family), @@ -68,6 +66,39 @@ def get_flow(self) -> list: ] +class SimpleKnowledgeFlow(_SimpleFlow): + def get_flow(self) -> list: + flow = super().get_flow() + sdg_base = resources.files(__package__) + flow[0]["block_config"]["config_path"] = os.path.join( + sdg_base, "configs/knowledge/simple_generate_qa.yaml" + ) + flow[0]["block_config"]["block_name"] = "gen_knowledge" + return flow + + +class SimpleFreeformSkillFlow(_SimpleFlow): + def get_flow(self) -> list: + flow = super().get_flow() + sdg_base = resources.files(__package__) + flow[0]["block_config"]["config_path"] = os.path.join( + sdg_base, "configs/skills/simple_generate_qa_freeform.yaml" + ) + flow[0]["block_config"]["block_name"] = "gen_skill_freeform" + return flow + + +class SimpleGroundedSkillFlow(_SimpleFlow): + def get_flow(self) -> list: + flow = super().get_flow() + sdg_base = resources.files(__package__) + flow[0]["block_config"]["config_path"] = os.path.join( + sdg_base, "configs/skills/simple_generate_qa_grounded.yaml" + ) + flow[0]["block_config"]["block_name"] = "gen_skill_grounded" + return flow + + class MMLUBenchFlow(Flow): def get_flow(self) -> list: sdg_base = resources.files(__package__) diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index 24703e22..89083f52 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -2,447 +2,163 @@ # Standard from datetime import datetime -from functools import partial from pathlib import Path from typing import Optional import json -import multiprocessing import os -import random -import re -import string import time # Third Party -from jinja2 import Template -from rouge_score import rouge_scorer -import click - -# instructlab - this needs to go away - issue #6 -import instructlab.utils -import tqdm +# instructlab - All of these need to go away (other than sdg) - issue #6 +from datasets import Dataset +from instructlab.utils import get_sysprompt +import httpx +import openai # First Party # pylint: disable=ungrouped-imports -from instructlab.sdg import utils -from instructlab.sdg.utils import chunking -from instructlab.sdg.utils import json as json_utils -from instructlab.sdg.utils import models, openai -from instructlab.sdg.utils import taxonomy as taxonomy_utils - -DEFAULT_PROMPT_TEMPLATE_MERLINITE = """\ -You are asked to come up with a set of 5 diverse task instructions under {{taxonomy}}{{" for the task \\"%s\\""|format(task_description) if task_description}}. These task instructions will be given to a GPT model and we will evaluate the GPT model for completing the instructions. - -Here are the requirements: -1. Try not to repeat the verb for each instruction to maximize diversity. -2. 
The language used for the instruction also should be diverse. For example, you should combine questions with imperative instructions. -{% if not document -%} -3. The type of instructions should not have topic diversity. The list should follow the same topic and category. -{% else -%} -3. The type of instructions should be similar to provided examples. The generated instruction and the output should be grounded in the provided document. -{% endif -%} -4. A GPT language model should be able to complete the instruction. For example, do not ask the assistant to create any visual or audio output. For another example, do not ask the assistant to wake you up at 5pm or set a reminder because it cannot perform any action. -5. The instructions should be in English. -6. The instructions should be 1 to 2 sentences long. Either an imperative sentence or a question is permitted. -{% if not document -%} -7. You should generate an appropriate input to the instruction. The input field should contain a specific example provided for the instruction. It should involve realistic data and should not contain simple placeholders. The input should provide substantial content to make the instruction challenging but should ideally not exceed 100 words. -8. Not all instructions require input. For example, when an instruction asks about some general information, "what is the highest peak in the world", it is not necessary to provide a specific context. In this case, we simply put "" in the input field. -9. The output should be an appropriate response to the instruction and the input. Make sure the output is less than 100 words. -{% else -%} -7. The output should be an appropriate response to the input and the instruction. Long outputs are preferable. -{% endif %} - -{% if not document -%} -List of 5 tasks: -{% else -%} -Based on below document provide a list of 5 tasks: - -Document: -{{document}} - -Here are some examples to help you understand the type of questions that are asked for this document: -{% endif -%} -""" - -DEFAULT_PROMPT_TEMPLATE_MIXTRAL = """\ - [INST]You are a very knowledgeable AI Assistant that will faithfully assist the user with their task. You are asked to come up with a set of 5 diverse task instructions under {{taxonomy}}{{" for the task \\"%s\\""|format(task_description) if task_description}}. These task instructions will be given to a GPT model and we will evaluate the GPT model for completing the instructions. -Here are the requirements: -1. Try not to repeat the verb for each instruction to maximize diversity. -2. The language used for the instruction also should be diverse. For example, you should combine questions with imperative instructions. -{% if not document -%} -3. The type of instructions should not have topic diversity. The list should follow the same topic and category. -{% else -%} -3. The type of instructions should be similar to provided examples. The generated instruction and the output should be grounded in the provided document. -{% endif -%} -4. A GPT language model should be able to complete the instruction. For example, do not ask the assistant to create any visual or audio output. For another example, do not ask the assistant to wake you up at 5pm or set a reminder because it cannot perform any action. -5. The instructions should be in English. -6. The instructions should be 1 to 2 sentences long. Either an imperative sentence or a question is permitted. -{% if not document -%} -7. You should generate an appropriate input to the instruction. 
The input field should contain a specific example provided for the instruction. It should involve realistic data and should not contain simple placeholders. The input should provide substantial content to make the instruction challenging but should ideally not exceed 100 words. -8. Not all instructions require input. For example, when an instruction asks about some general information, "what is the highest peak in the world", it is not necessary to provide a specific context. In this case, we simply put "" in the input field. -9. The output should be an appropriate response to the instruction and the input. Make sure the output is less than 100 words. -{% else -%} -7. The output should be an appropriate response to the input and the instruction. Long outputs are preferable. -{% endif %} -{% if not document -%} -List of 5 tasks: -{% else -%} -Based on below document provide a list of 5 tasks: -Document: -{{document}} -Here are some examples to help you understand the type of questions that are asked for this document: -{% endif -%}[/INST] -""" - -_WORD_DENYLIST = [ - "image", - "images", - "graph", - "graphs", - "picture", - "pictures", - "file", - "files", - "map", - "maps", - "draw", - "plot", - "go to", - "video", - "audio", - "music", - "flowchart", - "diagram", -] - - -def check_prompt_file(prompt_file_path, model_family): - """Check for prompt file.""" - try: - with open(prompt_file_path, encoding="utf=8") as file: - prompt_template = file.read() - except FileNotFoundError as exc: - print( - f"Cannot find {prompt_file_path}. Using default prompt depending on model-family." - ) - if model_family == "merlinite": - prompt_template = DEFAULT_PROMPT_TEMPLATE_MERLINITE - elif model_family == "mixtral": - prompt_template = DEFAULT_PROMPT_TEMPLATE_MIXTRAL - else: - raise ValueError(f"Unsupported family '{model_family}': {exc}") from exc - prompt_template = prompt_template.strip() + "\n" - return prompt_template - - -def encode_prompt(prompt_instructions, prompt): - """Encode multiple prompt instructions into a single string. - If documents exist, randomly select one.""" - idx = 0 - document = None - document_list = prompt_instructions[0].get("document") +from instructlab.sdg import SDG, utils +from instructlab.sdg.default_flows import ( + MODEL_FAMILY_MERLINITE, + MODEL_FAMILY_MIXTRAL, + MMLUBenchFlow, + SimpleFreeformSkillFlow, + SimpleGroundedSkillFlow, + SimpleKnowledgeFlow, + SynthKnowledgeFlow, +) +from instructlab.sdg.pipeline import Pipeline +from instructlab.sdg.utils import chunking, models +from instructlab.sdg.utils.taxonomy import ( + leaf_node_to_samples, + read_taxonomy_leaf_nodes, +) + + +def _unescape(s): + return bytes(s, "utf-8").decode("utf-8") - if document_list: - document = random.choice(document_list) - prompt = Template(prompt).render( - taxonomy=prompt_instructions[0]["taxonomy_path"], - task_description=prompt_instructions[0]["task_description"], - document=document, - ) +# This is a hack because the simple workflow returns a q/a pair as a single output. +# We could possibly try to ask for them separately, but it would cost twice the inference +# API calls. All of this is because the smallest models we use on small environments +# for testing and demos weren't good enough to follow the strict formatting instructions used +# in the full pipeline. 
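# As a concrete illustration of the heuristic implemented by the two helpers
# below (the sample text is made up): when only a combined "output" string is
# available, e.g.
#     "What is the boiling point of water? It boils at 100 degrees Celsius."
# the string is split on the first "?", so the question becomes
#     "What is the boiling point of water?"
# and the response becomes
#     "It boils at 100 degrees Celsius."
# If no "?" is present, a warning is logged, the question falls back to an
# empty string, and the response falls back to the whole output.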
+def _get_question(logger, synth_example): + if "question" in synth_example: + return synth_example["question"] - # pylint: disable=unused-variable - for idx, task_dict in enumerate(prompt_instructions): - ( - instruction, - prompt_input, - prompt_output, - taxonomy_path, - ) = ( - task_dict["instruction"], - task_dict["input"], - task_dict["output"], - task_dict["taxonomy_path"], - ) - instruction = re.sub(r"\s+", " ", instruction).strip().rstrip(":") - prompt_input = "" if prompt_input.lower() == "" else prompt_input - prompt += f"* Task {idx + 1}\n" - prompt += f"** Instruction\n{instruction}\n" - prompt += f"** Input\n{prompt_input}\n" - prompt += f"** Output\n{prompt_output}\n" - prompt += f"* Task {idx + 2}\n" - return prompt - - -def writeline2file(logfile, line): - t = datetime.now().replace(microsecond=0).isoformat() - with open(logfile, "a", encoding="utf-8") as fp: - fp.write(f"{t} - {line}\n") - - -def post_process_gpt3_response(num_prompt_instructions, response, discarded_file): - if response is None: - return [], 0 - raw_instructions = ( - # https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices - f"* Task {num_prompt_instructions + 1}\n" + response.message.content - ) - raw_instructions = re.split(r"\* Task \d+", raw_instructions) - instructions = [] - discarded = 0 - for inst in raw_instructions: - if not inst.strip(): - continue - - splitted_data = re.split(r"\*\*\s+(Instruction|Input|Output):?", inst) - if len(splitted_data) != 7: - writeline2file( - discarded_file, - "Discarded instruction(didn't match expected format): " + repr(inst), - ) - discarded += 1 - continue - inst = splitted_data[2].strip() - prompt_input = splitted_data[4].strip() - prompt_input = "" if prompt_input.lower() == "" else prompt_input - prompt_output = splitted_data[6].strip() - # filter out too short or too long instructions - if len(inst.split()) <= 3 or len(inst.split()) > 150: - writeline2file( - discarded_file, - "Discarded instruction(wrong number of words): " + repr(splitted_data), - ) - discarded += 1 - continue - # filter based on keywords that are not suitable for language models. - if any(find_word_in_string(word, inst) for word in _WORD_DENYLIST): - writeline2file( - discarded_file, - "Discarded instruction(contained a word from the denylist): " - + repr(splitted_data), - ) - discarded += 1 - continue - # We found that the model tends to add "write a program" to some existing instructions - # which lead to a lot of such instructions and it's confusing whether the model needs - # to write a program or directly output the result, so here we filter them out. - # NOTE: this is not a comprehensive filtering for all programming instructions. 
- if inst.startswith("Write a program"): - writeline2file( - discarded_file, - "Discarded instruction(began with 'Write a program'): " - + repr(splitted_data), - ) - discarded += 1 - continue - # filter those starting with punctuation - if inst[0] in string.punctuation: - writeline2file( - discarded_file, - "Discarded instruction(began with punctuation): " + repr(splitted_data), - ) - discarded += 1 - continue - # filter those starting with non-english character - if not inst[0].isascii(): - writeline2file( - discarded_file, - "Discarded instruction(began with non-ascii): " + repr(splitted_data), - ) - discarded += 1 - continue - instructions.append( - {"instruction": inst, "input": prompt_input, "output": prompt_output} + if "output" not in synth_example: + raise utils.GenerateException( + f"Error: output not found in synth_example: {synth_example}" ) - return instructions, discarded + parts = synth_example["output"].split("?", 1) + if len(parts) != 2: + logger.warning(f"Failed to split generated q&a: {synth_example['output']}") + return parts[0].strip() + "?" if len(parts) == 2 else "" -def find_word_in_string(w, s): - return re.compile(r"\b({0})\b".format(w), flags=re.IGNORECASE).search(s) +# This is also a hack. See the comment above _get_question. +def _get_response(logger, synth_example): + if "response" in synth_example: + return synth_example["response"] -def get_instructions_from_model( - logger, - request_idx, - instruction_data_pool, - prompt_template, - api_base, - api_key, - model_name, - num_prompt_instructions, - request_batch_size, - temperature, - top_p, - output_file_discarded, - tls_insecure, - tls_client_cert, - tls_client_key, - tls_client_passwd, -): - batch_inputs = [] - for _ in range(request_batch_size): - # only sampling from the seed tasks - try: - prompt_instructions = random.sample( - instruction_data_pool, num_prompt_instructions - ) - except ValueError as exc: - raise utils.GenerateException( - f"There was a problem with the new data, please make sure the " - f"yaml is formatted correctly, and there is enough " - f"new data({num_prompt_instructions}+ Q&A))" - ) from exc - prompt = encode_prompt(prompt_instructions, prompt_template) - batch_inputs.append(prompt) - decoding_args = openai.OpenAIDecodingArguments( - temperature=temperature, - n=1, - # Hard-coded to maximize length. - # Requests will be automatically adjusted. - max_tokens=3072, - top_p=top_p, - stop=["* Task 5"], - ) - request_start = time.time() - try: - results = openai.openai_completion( - api_base=api_base, - api_key=api_key, - prompts=batch_inputs, - model_name=model_name, - tls_insecure=tls_insecure, - tls_client_cert=tls_client_cert, - tls_client_key=tls_client_key, - tls_client_passwd=tls_client_passwd, - batch_size=request_batch_size, - decoding_args=decoding_args, - ) - except utils.GenerateException as exc: - # Attempt to log and gracefully recover from exceeding the server's - # maximum context length. This won't work for all servers. - # - # Both llama_cpp_python and vllm use this exact string in their error - # responses when exceeding the model's max content length. Other - # OpenAI-compatible servers may as well, but no guarantees. - if "model's maximum context length" in str(exc): - logger.warn( - "Generated prompt exceeded the server's maximum context length. " - "If you see this warning many times during generation, lower " - "the length of your example question and answers or raise the " - "server's maximum context size using `max_ctx_size`." 
- ) - return [], 0 - raise exc - request_duration = time.time() - request_start - - post_process_start = time.time() - instruction_data = [] - for result in results: - new_instructions, discarded = post_process_gpt3_response( - num_prompt_instructions, result, output_file_discarded + if "output" not in synth_example: + raise utils.GenerateException( + f"Error: output not found in synth_example: {synth_example}" ) - # make sure the generated instruction carried over extra fields - prompt_ins_0 = prompt_instructions[0] - for new_ins in new_instructions: - new_ins["taxonomy_path"] = prompt_ins_0["taxonomy_path"] - new_ins["task_description"] = prompt_ins_0["task_description"] - new_ins["document"] = prompt_ins_0["document"] - instruction_data += new_instructions - - post_process_duration = time.time() - post_process_start - logger.debug( - f"Request {request_idx} took {request_duration:.2f}s, " - f"post-processing took {post_process_duration:.2f}s" - ) - return instruction_data, discarded + parts = synth_example["output"].split("?", 1) + if len(parts) != 2: + logger.warning(f"Failed to split generated q&a: {synth_example['output']}") + return parts[1].strip() if len(parts) == 2 else parts[0].strip() -def unescape(s): - return bytes(s, "utf-8").decode("utf-8") +def _gen_train_data(logger, machine_instruction_data, output_file_train): + train_data = [] + for synth_example in machine_instruction_data: + logger.debug(synth_example) + user = _get_question(logger, synth_example) + if len(synth_example.get("context", "")) > 0: + user += "\n" + synth_example["context"] + train_data.append( + { + "system": get_sysprompt(), + "user": _unescape(user), + "assistant": _unescape(_get_response(logger, synth_example)), + } + ) + + with open(output_file_train, "w", encoding="utf-8") as outfile: + for entry in train_data: + json.dump(entry, outfile, ensure_ascii=False) + outfile.write("\n") def _gen_test_data( - logger, - seed_instruction_data, - max_seed_tokens, - taxonomy, - chunk_word_count, - server_ctx_size, + leaf_nodes, output_file_test, ): - max_seed_chars = chunking.num_chars_from_tokens(max_seed_tokens) - for seed_example in seed_instruction_data: - if ( - len(seed_example["instruction"]) - + len(seed_example["input"]) - + len(seed_example["output"]) - >= max_seed_chars - ): - raise SystemExit( - f"Error: An example in the taxonomy path {seed_example['taxonomy_path']} is too long for the server context size of {server_ctx_size}. Ensure the total number of characters across the combined question, answer, and context is less than {max_seed_chars} for each example or use a server with a larger context size." - ) - - seeds = len(seed_instruction_data) - logger.debug(f"Loaded {seeds} human-written seed instructions from {taxonomy}") - if not seeds: - raise SystemExit("Nothing to generate. 
Exiting.") - test_data = [] - for seed_example in seed_instruction_data: - user = seed_example["instruction"] - - documents = seed_example["document"] - if documents: - seed_example["document"] = chunking.chunk_document( - documents=documents, - server_ctx_size=server_ctx_size, - chunk_word_count=chunk_word_count, - ) + for _, leaf_node in leaf_nodes.items(): + for seed_example in leaf_node: + user = seed_example["instruction"] # question + + if len(seed_example["input"]) > 0: + user += "\n" + seed_example["input"] # context - if len(seed_example["input"]) > 0: - user += "\n" + seed_example["input"] - try: test_data.append( { - "system": instructlab.utils.get_sysprompt(), - "user": unescape(user), - "assistant": unescape(seed_example["output"]), + "system": get_sysprompt(), + "user": _unescape(user), + "assistant": _unescape(seed_example["output"]), # answer } ) - except TypeError as exc: - click.secho( - f"Error reading seed examples: {exc}. Please make sure your answers are verbose enough.", - fg="red", - ) - raise click.exceptions.Exit(1) - # json_utils.jdump(test_data, os.path.join(output_dir, output_file_test)) + with open(output_file_test, "w", encoding="utf-8") as outfile: for entry in test_data: json.dump(entry, outfile, ensure_ascii=False) outfile.write("\n") -def _gen_train_data(machine_instruction_data, output_file_train): - train_data = [] - for synth_example in machine_instruction_data: - user = synth_example["instruction"] - if len(synth_example["input"]) > 0: - user += "\n" + synth_example["input"] - train_data.append( - { - "system": instructlab.utils.get_sysprompt(), - "user": unescape(user), - "assistant": unescape(synth_example["output"]), - } - ) - # json_utils.jdump(train_data, output_file_train) - with open(output_file_train, "w", encoding="utf-8") as outfile: - for entry in train_data: - json.dump(entry, outfile, ensure_ascii=False) - outfile.write("\n") +def _sdg_init(profile, client, model_family, model_name, batched): + knowledge_flow_types = [] + freeform_skill_flow_types = [] + grounded_skill_flow_types = [] + if profile == "full": + knowledge_flow_types.append(MMLUBenchFlow) + knowledge_flow_types.append(SynthKnowledgeFlow) + elif profile == "simple": + knowledge_flow_types.append(SimpleKnowledgeFlow) + freeform_skill_flow_types.append(SimpleFreeformSkillFlow) + grounded_skill_flow_types.append(SimpleGroundedSkillFlow) + else: + raise utils.GenerateException(f"Error: profile ({profile}) is not supported.") + + sdg_knowledge = SDG( + [ + Pipeline(flow_type(client, model_family, model_name, batched).get_flow()) + for flow_type in knowledge_flow_types + ] + ) + sdg_freeform_skill = SDG( + [ + Pipeline(flow_type(client, model_family, model_name, batched).get_flow()) + for flow_type in freeform_skill_flow_types + ] + ) + sdg_grounded_skill = SDG( + [ + Pipeline(flow_type(client, model_family, model_name, batched).get_flow()) + for flow_type in grounded_skill_flow_types + ] + ) + return sdg_knowledge, sdg_freeform_skill, sdg_grounded_skill +# TODO - parameter removal needs to be done in sync with a CLI change. +# pylint: disable=unused-argument def generate_data( logger, api_base, @@ -452,14 +168,23 @@ def generate_data( output_dir: Optional[str] = None, taxonomy: Optional[str] = None, taxonomy_base: Optional[str] = None, + # TODO - not used and should be removed from the CLI prompt_file_path: Optional[str] = None, model_name: Optional[str] = None, + # TODO - not used -- when batching is enabled, this is relevant. 
+ # Right now the code hard codes 8 cpus for batching num_cpus: Optional[int] = None, + # TODO - not yet used, but should be presumably num_instructions_to_generate: Optional[int] = None, + # TODO - not used, can probably be removed num_prompt_instructions=2, + # TODO - determine if this is relevant request_batch_size=5, + # TODO - probably should be removed temperature=1.0, + # TODO - probably should be removed top_p=1.0, + # TODO - probably should be removed rouge_threshold: Optional[float] = None, console_output=True, api_key: Optional[str] = None, @@ -468,181 +193,120 @@ def generate_data( tls_client_cert: Optional[str] = None, tls_client_key: Optional[str] = None, tls_client_passwd: Optional[str] = None, + # TODO need to update the CLI to specify which profile to use (simple or full at the moment) + profile: Optional[str] = "simple", ): - seed_instruction_data = [] - machine_seed_instruction_data = [] generate_start = time.time() if not os.path.exists(output_dir): os.mkdir(output_dir) - # check taxonomy first then seed_tasks_path - # throw an error if both not found - # pylint: disable=broad-exception-caught,raise-missing-from - if taxonomy and os.path.exists(taxonomy): - seed_instruction_data = taxonomy_utils.read_taxonomy( - taxonomy, taxonomy_base, yaml_rules - ) - else: - raise SystemExit(f"Error: taxonomy ({taxonomy}) does not exist.") + if not (taxonomy and os.path.exists(taxonomy)): + raise utils.GenerateException(f"Error: taxonomy ({taxonomy}) does not exist.") - prompt_template = check_prompt_file( - prompt_file_path, models.get_model_family(model_family, model_name) - ) - max_seed_tokens = chunking.max_seed_example_tokens( - server_ctx_size, len(prompt_template) - ) + leaf_nodes = read_taxonomy_leaf_nodes(taxonomy, taxonomy_base, yaml_rules) + if not leaf_nodes: + raise utils.GenerateException("Error: No new leaf nodes found in the taxonomy.") name = Path(model_name).stem # Just in case it is a file path date_suffix = datetime.now().replace(microsecond=0).isoformat().replace(":", "_") - output_file = f"generated_{name}_{date_suffix}.json" - output_file_train = f"train_{name}_{date_suffix}.jsonl" + output_file_generated = f"generated_{name}_{date_suffix}.json" output_file_test = f"test_{name}_{date_suffix}.jsonl" - output_file_discarded = os.path.join( - output_dir, f"discarded_{name}_{date_suffix}.log" - ) + output_file_train = f"train_{name}_{date_suffix}.jsonl" + _gen_test_data( - logger, - seed_instruction_data, - max_seed_tokens, - taxonomy, - chunk_word_count, - server_ctx_size, + leaf_nodes, os.path.join(output_dir, output_file_test), ) - logger.debug(f"Generating to: {os.path.join(output_dir, output_file)}") - - request_idx = 0 - # load the LM-generated instructions - machine_instruction_data = [] - if os.path.exists(os.path.join(output_dir, "regen.json")): - machine_instruction_data = json_utils.jload( - os.path.join(output_dir, "regen.json") - ) - logger.debug( - f"Loaded {len(machine_instruction_data)} machine-generated instructions" - ) - # similarities = {} - # Calculate rouges scores between two blobs of text. - # rougeL: Longest common subsequence based scoring. - # https://github.com/google-research/google-research/blob/master/rouge/rouge_scorer.py#L50 - scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False) - - # now let's generate new instructions! 
- progress_bar = tqdm.tqdm(total=num_instructions_to_generate) - if machine_instruction_data: - progress_bar.update(len(machine_instruction_data)) - - # first we tokenize all the seed instructions and generated machine instructions - all_instructions = [d["instruction"] for d in seed_instruction_data] + [ - d["instruction"] for d in machine_instruction_data - ] - all_instruction_tokens = [ - # https://github.com/google-research/google-research/blob/master/rouge/tokenize.py - scorer._tokenizer.tokenize(inst) - for inst in all_instructions - ] + logger.debug(f"Generating to: {os.path.join(output_dir, output_file_generated)}") + + orig_cert = (tls_client_cert, tls_client_key, tls_client_passwd) + cert = tuple(item for item in orig_cert if item) + verify = not tls_insecure + client = openai.OpenAI( + base_url=api_base, + api_key=api_key, + http_client=httpx.Client(cert=cert, verify=verify), + ) + + if models.get_model_family(model_family, model_name) == "mixtral": + model_family = MODEL_FAMILY_MIXTRAL + else: + model_family = MODEL_FAMILY_MERLINITE + + # TODO -- llama-cpp doesn't support batching, we need to get a hint from the CLI + # about whether we can turn this on (whether vllm is used or not) + batched = False + + sdg_knowledge, sdg_freeform_skill, sdg_grounded_skill = _sdg_init( + profile, client, model_family, model_name, batched + ) if console_output: - print( + logger.info( "Synthesizing new instructions. If you aren't satisfied with the generated instructions, interrupt training (Ctrl-C) and try adjusting your YAML files. Adding more examples may help." ) - mpctx = multiprocessing.get_context(None) - - all_taxonomy_paths = list(set(e["taxonomy_path"] for e in seed_instruction_data)) - total_discarded = 0 - total_rouged = 0 - while len(machine_instruction_data) < num_instructions_to_generate: - request_idx += 1 - - # Pick taxonomy path - selected_taxonomy = all_taxonomy_paths[request_idx % len(all_taxonomy_paths)] - logger.info(f"Selected taxonomy path {selected_taxonomy}") - # Filter the pool - instruction_data_pool = [ - e - for e in seed_instruction_data + machine_seed_instruction_data - if e["taxonomy_path"] == selected_taxonomy - ] - instruction_data, discarded = get_instructions_from_model( - logger, - request_idx, - instruction_data_pool, - prompt_template, - api_base, - api_key, - model_name, - num_prompt_instructions, - request_batch_size, - temperature, - top_p, - output_file_discarded, - tls_insecure, - tls_client_cert, - tls_client_key, - tls_client_passwd, - ) - total_discarded += discarded - total = len(instruction_data) - keep = 0 - assess_start = time.time() - for instruction_data_entry in instruction_data: - # computing similarity with the pre-tokenized instructions - # https://github.com/google-research/google-research/blob/master/rouge/tokenize.py - new_instruction_tokens = scorer._tokenizer.tokenize( - instruction_data_entry["instruction"] + generated_data = None + for leaf_node in leaf_nodes.values(): + samples = leaf_node_to_samples(leaf_node) + + if not samples: + raise utils.GenerateException("Error: No samples found in leaf node.") + + sdg = None + if samples[0].get("document"): + sdg = sdg_knowledge + elif samples[0].get("context"): + sdg = sdg_grounded_skill + else: + sdg = sdg_freeform_skill + + if not sdg: + # TODO - can be removed once the "full" pipelines are all defined, + # as there shouldn't be a code path to get here anymore + raise utils.GenerateException( + "Error: No SDG pipeline for this leaf node type: %s" % samples[0] ) - with 
mpctx.Pool(num_cpus) as pool: - rouge_scores = pool.map( - partial(rouge_scorer._score_lcs, new_instruction_tokens), - all_instruction_tokens, - ) - pool.join() - instruction_data_entry["taxonomy_path"] = selected_taxonomy - rouge_scores = [score.fmeasure for score in rouge_scores] - # Comment out extra info not currently being used: - # most_similar_instructions = { - # all_instructions[i]: rouge_scores[i] for i in np.argsort(rouge_scores)[-10:][::-1] - # } - if max(rouge_scores) > rouge_threshold: - total_rouged += 1 - continue - keep += 1 - # Comment out extra info not currently being used: - # instruction_data_entry["most_similar_instructions"] = most_similar_instructions - # instruction_data_entry["avg_similarity_score"] = float(np.mean(rouge_scores)) - - # Only add sufficiently small instructions to our machine seeds - if len(new_instruction_tokens) <= max_seed_tokens: - machine_seed_instruction_data.append(instruction_data_entry) - - machine_instruction_data.append(instruction_data_entry) - all_instructions.append(instruction_data_entry["instruction"]) - all_instruction_tokens.append(new_instruction_tokens) - if console_output: - print( - f"Q> {instruction_data_entry['instruction']}\nI> {instruction_data_entry['input']}\nA> {instruction_data_entry['output']}\n" - ) - progress_bar.update(keep) - assess_duration = time.time() - assess_start - logger.debug(f"Assessing generated samples took {assess_duration:.2f}s") - logger.debug( - f"Generated {total} instructions(discarded {discarded}), rouged {total - keep}, kept {keep} instructions" - ) - json_utils.jdump( - machine_instruction_data, os.path.join(output_dir, output_file) - ) - _gen_train_data( - machine_instruction_data, os.path.join(output_dir, output_file_train) + + # TODO this is broken, just trying to get initial integration to run + # pylint: disable=consider-using-enumerate + if samples[0].get("document"): + for i in range(len(samples)): + samples[i]["document"] = chunking.chunk_document( + documents=samples[i]["document"], + server_ctx_size=server_ctx_size, + chunk_word_count=chunk_word_count, + )[0] + + # TODO -- there is a parameter for how many samples to generate, but we ignore it so far + + logger.debug("Samples: %s" % samples) + ds = Dataset.from_list(samples) + logger.debug("Dataset: %s" % ds) + new_generated_data = sdg.generate(ds) + generated_data = ( + new_generated_data + if generated_data is None + else generated_data + new_generated_data ) + logger.info("Generated %d samples" % len(generated_data)) + logger.debug("Generated data: %s" % generated_data) - progress_bar.close() + if generated_data is None: + generated_data = [] + + _gen_train_data(logger, generated_data, os.path.join(output_dir, output_file_train)) + + # TODO + # This is for backwards compatibility. The file existing previously, so we'll keep it for now. + # I believe the github bot assumes it is present for presenting generated data to a taxonomy + # reviewer or contributor. Otherwise, I don't see a consumer of it in this repo or the + # `ilab` CLI. 
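    # For reference, _gen_train_data() writes one JSON object per line, so this
    # generated_*.json file has the same shape as the train_*.jsonl file above,
    # e.g. (values are illustrative only):
    #   {"system": "<system prompt from get_sysprompt()>",
    #    "user": "<generated question, plus context when present>",
    #    "assistant": "<generated answer>"}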
+ _gen_train_data( + logger, generated_data, os.path.join(output_dir, output_file_generated) + ) - if total_discarded or total_rouged: - logger.info( - f"{len(machine_instruction_data)} instructions generated, {total_discarded} discarded due to format (see {output_file_discarded}), {total_rouged} discarded due to rouge score" - ) generate_duration = time.time() - generate_start logger.info(f"Generation took {generate_duration:.2f}s") diff --git a/src/instructlab/sdg/utils/chunking.py b/src/instructlab/sdg/utils/chunking.py index 281410ee..79f1e16c 100644 --- a/src/instructlab/sdg/utils/chunking.py +++ b/src/instructlab/sdg/utils/chunking.py @@ -13,51 +13,10 @@ def _num_tokens_from_words(num_words) -> int: return int(num_words * 1.3) # 1 word ~ 1.3 token -def num_chars_from_tokens(num_tokens) -> int: +def _num_chars_from_tokens(num_tokens) -> int: return int(num_tokens * 4) # 1 token ~ 4 English character -def _num_tokens_from_chars(num_chars) -> int: - return int(num_chars / 4) # 1 token ~ 4 English character - - -def max_seed_example_tokens(server_ctx_size, prompt_num_chars) -> int: - """ - Estimates the maximum number of tokens any seed example can have based - on the server context size and number of characters in the selected prompt. - - A lot has to fit into the given server context size: - - The prompt itself, which can vary in size a bit based on model family and knowledge vs skill - - Two seed examples, which we append to the prompt template. - - A knowledge document chunk, if this is a knowledge example. - - The generated completion, which can vary substantially in length. - - This is an attempt to roughly estimate the maximum size any seed example - (question + answer + context values from the yaml) should be to even have - a hope of not often exceeding the server's maximum context size. - - NOTE: This does not take into account knowledge document chunks. It's meant - to calculate the maximum size that any seed example should be, whether knowledge - or skill. Knowledge seed examples will want to stay well below this limit. - - NOTE: This is a very simplistic calculation, and examples with lots of numbers - or punctuation may have quite a different token count than the estimates here, - depending on the model (and thus tokenizer) in use. That's ok, as it's only - meant to be a rough estimate. - - Args: - server_ctx_size (int): Size of the server context, in tokens. - prompt_num_chars (int): Number of characters in the prompt (not including the examples) - """ - # Ensure we have at least 1024 tokens available for a response. - max_seed_tokens = server_ctx_size - 1024 - # Subtract the number of tokens in our prompt template - max_seed_tokens = max_seed_tokens - _num_tokens_from_chars(prompt_num_chars) - # Divide number of characters by 2, since we insert 2 examples - max_seed_tokens = int(max_seed_tokens / 2) - return max_seed_tokens - - def chunk_document(documents: List, server_ctx_size, chunk_word_count) -> List[str]: """ Iterates over the documents and splits them into chunks based on the word count provided by the user. 
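For sizing intuition, the helpers above assume roughly 1.3 tokens per word and 4 characters per token. A quick worked example (the word count is an arbitrary illustration):

    chunk_word_count = 1000
    tokens = int(chunk_word_count * 1.3)   # _num_tokens_from_words -> 1300
    chunk_size = int(tokens * 4)           # _num_chars_from_tokens -> 5200 characters

This matches the arithmetic the chunking test at the end of this patch performs with _num_tokens_from_words() and _num_chars_from_tokens().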
@@ -80,7 +39,7 @@ def chunk_document(documents: List, server_ctx_size, chunk_word_count) -> List[s content = [] text_splitter = RecursiveCharacterTextSplitter( separators=["\n\n", "\n", " "], - chunk_size=num_chars_from_tokens(no_tokens_per_doc), + chunk_size=_num_chars_from_tokens(no_tokens_per_doc), chunk_overlap=_DEFAULT_CHUNK_OVERLAP, ) diff --git a/src/instructlab/sdg/utils/openai.py b/src/instructlab/sdg/utils/openai.py deleted file mode 100644 index f11ef4f3..00000000 --- a/src/instructlab/sdg/utils/openai.py +++ /dev/null @@ -1,175 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# Standard -from typing import Optional, Sequence, Union -import copy -import dataclasses -import logging -import math -import sys - -# Third Party -# instructlab - TODO these need to go away, issue #6 -from instructlab.configuration import DEFAULT_API_KEY, DEFAULT_MODEL_OLD -from instructlab.utils import get_sysprompt -from openai import OpenAI, OpenAIError -import httpx - -# Local -from . import GenerateException - -StrOrOpenAIObject = Union[str, object] - - -# pylint: disable=too-many-instance-attributes -@dataclasses.dataclass -class OpenAIDecodingArguments: - max_tokens: int = 1800 - temperature: float = 0.2 - top_p: float = 1.0 - n: int = 1 - stream: bool = False - stop: Optional[Sequence[str]] = None - presence_penalty: float = 0.0 - frequency_penalty: float = 0.0 - logprobs: Optional[int] = None - - -def openai_completion( - api_base, - tls_insecure, - tls_client_cert, - tls_client_key, - tls_client_passwd, - prompts: Union[str, Sequence[str], Sequence[dict[str, str]], dict[str, str]], - decoding_args: OpenAIDecodingArguments, - model_name="ggml-merlinite-7b-lab-Q4_K_M", - batch_size=1, - max_instances=sys.maxsize, - max_batches=sys.maxsize, - return_text=False, - api_key=DEFAULT_API_KEY, - **decoding_kwargs, -) -> Union[ - Union[StrOrOpenAIObject], - Sequence[StrOrOpenAIObject], - Sequence[Sequence[StrOrOpenAIObject]], -]: - """Decode with OpenAI API. - - Args: - api_base: Endpoint URL where model is hosted - tls_insecure: Disable TLS verification - tls_client_cert: Path to the TLS client certificate to use - tls_client_key: Path to the TLS client key to use - tls_client_passwd: TLS client certificate password - prompts: A string or a list of strings to complete. If it is a chat model the strings - should be formatted as explained here: - https://github.com/openai/openai-python/blob/main/chatml.md. - If it is a chat model it can also be a dictionary (or list thereof) as explained here: - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb - decoding_args: Decoding arguments. - model_name: Model name. Can be either in the format of "org/model" or just "model". - batch_size: Number of prompts to send in a single request. Only for non chat model. - max_instances: Maximum number of prompts to decode. - max_batches: Maximum number of batches to decode. This will be deprecated in the future. - return_text: If True, return text instead of full completion object (e.g. includes logprob). - api_key: API key API key for API endpoint where model is hosted - decoding_kwargs: Extra decoding arguments. Pass in `best_of` and `logit_bias` if needed. - - Returns: - A completion or a list of completions. 
Depending on return_text, return_openai_object, - and decoding_args.n, the completion type can be one of: - - a string (if return_text is True) - - an openai_object.OpenAIObject object (if return_text is False) - - a list of objects of the above types (if decoding_args.n > 1) - """ - is_single_prompt = isinstance(prompts, (str, dict)) - if is_single_prompt: - prompts = [prompts] - - if max_batches < sys.maxsize: - logging.warning( - "`max_batches` will be deprecated in the future, please use `max_instances` instead." - "Setting `max_instances` to `max_batches * batch_size` for now." - ) - max_instances = max_batches * batch_size - - prompts = prompts[:max_instances] - num_prompts = len(prompts) - prompt_batches = [ - prompts[batch_id * batch_size : (batch_id + 1) * batch_size] - for batch_id in range(int(math.ceil(num_prompts / batch_size))) - ] - - completions = [] - for batch_id, prompt_batch in enumerate(prompt_batches): - batch_decoding_args = copy.deepcopy(decoding_args) # cloning the decoding_args - - shared_kwargs = { - "model": model_name, - **batch_decoding_args.__dict__, - **decoding_kwargs, - } - - if not api_key: - # we need to explicitly set non-empty api-key, to ensure generate - # connects to our local server - api_key = "no_api_key" - - # do not pass a lower timeout to this client since generating a dataset takes some time - # pylint: disable=R0801 - orig_cert = (tls_client_cert, tls_client_key, tls_client_passwd) - cert = tuple(item for item in orig_cert if item) - verify = not tls_insecure - client = OpenAI( - base_url=api_base, - api_key=api_key, - http_client=httpx.Client(cert=cert, verify=verify), - ) - - # ensure the model specified exists on the server. with backends like vllm, this is crucial. - model_list = client.models.list().data - model_ids = [] - for model in model_list: - model_ids.append(model.id) - if not any(model_name == m for m in model_ids): - if model_name == DEFAULT_MODEL_OLD: - logging.info( - "Model %s is not a full path. Try running ilab init or edit your config to have the full model path for serving, chatting, and generation.", - model_name, - ) - raise GenerateException( - f"Model {model_name} is not served by the server. These are the served models {model_ids}" - ) - - messages = [ - {"role": "system", "content": get_sysprompt()}, - {"role": "user", "content": prompt_batch[batch_id]}, - ] - - # Inference the model - try: - response = client.chat.completions.create( - messages=messages, - **shared_kwargs, - ) - except OpenAIError as exc: - raise GenerateException( - f"There was a problem connecting to the server {exc}" - ) from exc - - completions.extend(response.choices) - - if return_text: - completions = [completion.text for completion in completions] - if decoding_args.n > 1: - # make a nested list, where each entry is consecutive decoding_args.n of original entries. - completions = [ - completions[i : i + decoding_args.n] - for i in range(0, len(completions), decoding_args.n) - ] - if is_single_prompt: - # Return non-tuple if only 1 input and 1 generation. 
- (completions,) = completions - return completions diff --git a/src/instructlab/sdg/utils/taxonomy.py b/src/instructlab/sdg/utils/taxonomy.py index 63fad185..9e62baa5 100644 --- a/src/instructlab/sdg/utils/taxonomy.py +++ b/src/instructlab/sdg/utils/taxonomy.py @@ -17,6 +17,9 @@ import gitdb import yaml +# First Party +from instructlab.sdg import utils + logger = logging.getLogger(__name__) DEFAULT_YAML_RULES = """\ @@ -331,6 +334,7 @@ def _read_taxonomy_file(file_path: str, yaml_rules: Optional[str] = None): # get seed instruction data tax_path = "->".join(taxonomy_path.parent.parts) task_description = contents.get("task_description") + domain = contents.get("domain") documents = contents.get("document") if documents: documents = _get_documents(source=documents) @@ -348,6 +352,7 @@ def _read_taxonomy_file(file_path: str, yaml_rules: Optional[str] = None): "taxonomy_path": tax_path, "task_description": task_description, "document": documents, + "domain": domain, } ) except Exception as e: @@ -395,3 +400,57 @@ def read_taxonomy(taxonomy, taxonomy_base, yaml_rules): yaml.YAMLError(f"{total_errors} taxonomy files with errors! Exiting.") ) return seed_instruction_data + + +def read_taxonomy_leaf_nodes(taxonomy, taxonomy_base, yaml_rules): + seed_instruction_data = read_taxonomy(taxonomy, taxonomy_base, yaml_rules) + + # Transform into a more convenient format to feed into our updated SDG library + leaf_nodes = {} + for seed in seed_instruction_data: + node = leaf_nodes.setdefault(seed["taxonomy_path"], []) + node.append(seed) + leaf_nodes[seed["taxonomy_path"]] = node + + return leaf_nodes + + +def leaf_node_to_samples(leaf_node): + samples = [{}] + + # pylint: disable=consider-using-enumerate + for i in range(len(leaf_node)): + samples[-1].setdefault("task_description", leaf_node[i]["task_description"]) + for field in ["document", "domain"]: + if leaf_node[i].get(field): + samples[-1].setdefault(field, leaf_node[i][field]) + if samples[-1].get("document") and not samples[-1].get("domain"): + raise utils.GenerateException( + "Error: No domain provided for knowledge document in leaf node" + ) + if leaf_node[i].get("input"): + samples[-1].setdefault("context", leaf_node[i]["input"]) + if "question_3" in samples[-1]: + samples.append({}) + if "question_1" not in samples[-1]: + samples[-1]["question_1"] = leaf_node[i]["instruction"] + samples[-1]["response_1"] = leaf_node[i]["output"] + elif "question_2" not in samples[-1]: + samples[-1]["question_2"] = leaf_node[i]["instruction"] + samples[-1]["response_2"] = leaf_node[i]["output"] + else: + samples[-1]["question_3"] = leaf_node[i]["instruction"] + samples[-1]["response_3"] = leaf_node[i]["output"] + + # wrap back around to the beginning if the number of examples was not + # evenly divisble by 3 + if "question_2" not in samples[-1]: + samples[-1]["question_2"] = leaf_node[0]["instruction"] + samples[-1]["response_2"] = leaf_node[0]["output"] + if "question_3" not in samples[-1]: + samples[-1]["question_3"] = leaf_node[1 if len(leaf_node) > 1 else 0][ + "instruction" + ] + samples[-1]["response_3"] = leaf_node[1 if len(leaf_node) > 1 else 0]["output"] + + return samples diff --git a/tests/test_chunking.py b/tests/test_chunking.py index 6d23b2be..688e61e4 100644 --- a/tests/test_chunking.py +++ b/tests/test_chunking.py @@ -45,7 +45,7 @@ def test_chunk_docs_long_lines(self): server_ctx_size=4096, ) max_tokens = chunking._num_tokens_from_words(chunk_words) - max_chars = chunking.num_chars_from_tokens(max_tokens) + max_chars = 
chunking._num_chars_from_tokens(max_tokens) max_chars += chunking._DEFAULT_CHUNK_OVERLAP # add in the chunk overlap max_chars += 50 # and a bit extra for some really long words for chunk in chunks:
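Taken together, the new generate_data() path reads taxonomy leaf nodes, converts each into samples whose fields (task_description, question_1..question_3, response_1..response_3, plus document/domain or context where applicable) line up with the placeholders in the new YAML configs, selects a pipeline based on whether a sample carries a document (knowledge), a context (grounded skill), or neither (freeform skill), and runs SDG.generate() over a Dataset built from those samples. Below is a minimal sketch of that wiring for the simple knowledge profile; the server URL, API key, model name, and taxonomy paths are placeholder assumptions, and document chunking is omitted for brevity.

    # Sketch only: mirrors _sdg_init() and the per-leaf-node loop in generate_data.py.
    from datasets import Dataset
    from openai import OpenAI

    from instructlab.sdg import SDG
    from instructlab.sdg.default_flows import MODEL_FAMILY_MERLINITE, SimpleKnowledgeFlow
    from instructlab.sdg.pipeline import Pipeline
    from instructlab.sdg.utils.taxonomy import leaf_node_to_samples, read_taxonomy_leaf_nodes

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="no_api_key")  # assumed local server
    batched = False  # llama-cpp backends do not support batching

    flow = SimpleKnowledgeFlow(client, MODEL_FAMILY_MERLINITE, "merlinite-7b-lab", batched).get_flow()
    sdg_knowledge = SDG([Pipeline(flow)])

    leaf_nodes = read_taxonomy_leaf_nodes("taxonomy", "origin/main", None)  # paths/args assumed
    for leaf_node in leaf_nodes.values():
        samples = leaf_node_to_samples(leaf_node)
        # a real knowledge run chunks samples[i]["document"] with chunking.chunk_document() first
        generated = sdg_knowledge.generate(Dataset.from_list(samples))

The simple skill flows are wired identically; only the flow class and the YAML config under configs/skills/ differ.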