Skip to content

Commit

Permalink
Add simple pipeline for use with default merlinite
Browse files Browse the repository at this point in the history
The CLI's default model is quantized merlinite, and it does not seem
good enough to follow the instructions in the full pipeline included
in the new library.

This change attempts to still use the new library, but with a very
minimal configuration. It's not doing any validation on the output, so
the output is not going to be great.  Then again, the output has never
been great doing SDG with merlinite and the old sdg implementation.
This at least keeps the ability to a basic workflow test and demo on a
smaller system.

Signed-off-by: Russell Bryant <[email protected]>
  • Loading branch information
russellb committed Jun 26, 2024
1 parent 626b5ae commit 7edbc76
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 31 deletions.
37 changes: 37 additions & 0 deletions src/instructlab/sdg/configs/knowledge/simple_generate_qa.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
system: You are a very knowledgeable AI Assistant that will faithfully assist the user with their task.

introduction: Develop a series of educational question and answer pairs from a chapter in a {domain} textbook.

principles: |
Here are the requirements:
1. Try not to repeat the verb for each instruction to maximize diversity.
2. The language used for the instruction also should be diverse. For example, you should combine questions with imperative instructions.
3. The type of instructions should be similar to provided examples. The generated instruction and the output should be grounded in the provided document.
4. A GPT language model should be able to complete the instruction. For example, do not ask the assistant to create any visual or audio output. For another example, do not ask the assistant to wake you up at 5pm or set a reminder because it cannot perform any action.
5. The instructions should be in English.
6. The instructions should be 1 to 2 sentences long. Either an imperative sentence or a question is permitted.
7. The output should be an appropriate response to the input and the instruction. Long outputs are preferable.

examples: |
Here are some examples to help you understand the type of questions that are asked for this document:
{question_1}
{response_1}
{question_2}
{response_2}
{question_3}
{response_3}
Here is the document:
{document}
generation: |
Provide a single question and answer pair based on the document:
Document:
{{document}}
start_tags: [""]
end_tags: [""]
32 changes: 30 additions & 2 deletions src/instructlab/sdg/default_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
import os

# First Party
import instructlab.sdg.utils as utils
from instructlab.sdg import utils

# Local
from .filterblock import FilterByValueBlock
from .iterblock import IterBlock
from .llmblock import LLMBlock


MODEL_PROMPT_MIXTRAL = "<s> [INST] {prompt} [/INST]"
MODEL_PROMPT_MERLINITE = "'<|system|>\nYou are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.\n<|user|>\n{prompt}\n<|assistant|>\n'"

Expand All @@ -25,6 +24,7 @@ def _get_model_prompt(model_id):
else MODEL_PROMPT_MERLINITE
)


class Flow(ABC):
def __init__(self, client, model_id) -> None:
self.client = client
Expand All @@ -35,6 +35,34 @@ def get_flow(self) -> list:
pass


class SimpleKnowledgeFlow(Flow):
def get_flow(self) -> list:
sdg_base = resources.files(__package__)
return [
{
"block_type": LLMBlock,
"block_config": {
"block_name": "gen_knowledge",
"config_path": os.path.join(
sdg_base, "configs/knowledge/simple_generate_qa.yaml"
),
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_id),
"output_cols": ["output"],
"batch_kwargs": {
"num_procs": 8,
"batched": True,
},
},
"gen_kwargs": {
"max_tokens": 2048,
},
"drop_duplicates": ["output"],
},
]


class MMLUBenchFlow(Flow):
def get_flow(self) -> list:
sdg_base = resources.files(__package__)
Expand Down
103 changes: 80 additions & 23 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
# First Party
# pylint: disable=ungrouped-imports
from instructlab.sdg import SDG, utils
from instructlab.sdg.default_flows import MMLUBenchFlow, SynthKnowledgeFlow
from instructlab.sdg.default_flows import (
MMLUBenchFlow,
SimpleKnowledgeFlow,
SynthKnowledgeFlow,
)
from instructlab.sdg.pipeline import Pipeline

_WORD_DENYLIST = [
Expand Down Expand Up @@ -62,11 +66,12 @@ def unescape(s):
return bytes(s, "utf-8").decode("utf-8")


def _gen_train_data(machine_instruction_data, output_file_train):
def _gen_train_data(logger, machine_instruction_data, output_file_train):
train_data = []
for synth_example in machine_instruction_data:
user = synth_example["instruction"]
if len(synth_example["input"]) > 0:
logger.debug(synth_example)
user = synth_example.get("instruction", "")
if len(synth_example.get("input", "")) > 0:
user += "\n" + synth_example["input"]
train_data.append(
{
Expand All @@ -82,23 +87,36 @@ def _gen_train_data(machine_instruction_data, output_file_train):
outfile.write("\n")


# TODO - parameter removal needs to be done in sync with a CLI change.
# pylint: disable=unused-argument
def generate_data(
logger,
api_base,
tls_insecure,
# TODO - not yet used. Right now the lib will guess based on the model name
# but we should pass this along if specified
model_family: str,
yaml_rules: Optional[str] = None,
output_dir: Optional[str] = None,
taxonomy: Optional[str] = None,
taxonomy_base: Optional[str] = None,
# TODO - not used and should be removed from the CLI
prompt_file_path: Optional[str] = None,
model_name: Optional[str] = None,
# TODO - not used -- when batching is enabled, this is relevant.
# Right now the code hard codes 8 cpus for batching
num_cpus: Optional[int] = None,
# TODO - not yet used, but should be presumably
num_instructions_to_generate: Optional[int] = None,
# TODO - not used, can probably be removed
num_prompt_instructions=2,
# TODO - determine if this is relevant
request_batch_size=5,
# TODO - probably should be removed
temperature=1.0,
# TODO - probably should be removed
top_p=1.0,
# TODO - probably should be removed
rouge_threshold: Optional[float] = None,
console_output=True,
api_key: Optional[str] = None,
Expand All @@ -107,6 +125,8 @@ def generate_data(
tls_client_cert: Optional[str] = None,
tls_client_key: Optional[str] = None,
tls_client_passwd: Optional[str] = None,
# TODO need to update the CLI to specify which profile to use (simple or full at the moment)
profile: Optional[str] = "simple",
):
seed_instruction_data = []
generate_start = time.time()
Expand Down Expand Up @@ -147,25 +167,56 @@ def generate_data(
http_client=httpx.Client(cert=cert, verify=verify),
)

mmlu_flow = MMLUBenchFlow(client, model_name).get_flow()
# TODO -- llama-cpp doesn't support batching, we need to get a hint from the CLI
# about whether we can turn this on (whether vllm is used or not)
for i, _ in enumerate(mmlu_flow):
if "block_config" in mmlu_flow[i] and "batch_kwargs" in mmlu_flow[i]["block_config"]:
mmlu_flow[i]["block_config"]["batch_kwargs"]["batched"] = False
logger.debug(mmlu_flow[i])
else:
logger.debug("No batch_kwargs in mmlu_flow: %s" % mmlu_flow[i])
knowledge_flow = SynthKnowledgeFlow(client, model_name).get_flow()
for i, _ in enumerate(knowledge_flow):
if "block_config" in knowledge_flow[i] and "batch_kwargs" in knowledge_flow[i]["block_config"]:
knowledge_flow[i]["block_config"]["batch_kwargs"]["batched"] = False
logger.debug(knowledge_flow[i])
else:
logger.debug("No batch_kwargs in knowledge_flow: %s" % knowledge_flow[i])
knowledge_pipe = Pipeline(knowledge_flow)
mmlu_pipe = Pipeline(mmlu_flow)
sdg = SDG([mmlu_pipe, knowledge_pipe])
sdg = None
if profile == "full":
mmlu_flow = MMLUBenchFlow(client, model_name).get_flow()
# TODO -- llama-cpp doesn't support batching, we need to get a hint from the CLI
# about whether we can turn this on (whether vllm is used or not)
for i, _ in enumerate(mmlu_flow):
if (
"block_config" in mmlu_flow[i]
and "batch_kwargs" in mmlu_flow[i]["block_config"]
):
mmlu_flow[i]["block_config"]["batch_kwargs"]["batched"] = False
logger.debug(mmlu_flow[i])
else:
logger.debug("No batch_kwargs in mmlu_flow: %s" % mmlu_flow[i])
mmlu_flow[i]["block_config"]["logger"] = logger
knowledge_flow = SynthKnowledgeFlow(client, model_name).get_flow()
for i, _ in enumerate(knowledge_flow):
if (
"block_config" in knowledge_flow[i]
and "batch_kwargs" in knowledge_flow[i]["block_config"]
):
knowledge_flow[i]["block_config"]["batch_kwargs"]["batched"] = False
logger.debug(knowledge_flow[i])
else:
logger.debug(
"No batch_kwargs in knowledge_flow: %s" % knowledge_flow[i]
)
knowledge_flow[i]["block_config"]["logger"] = logger
knowledge_pipe = Pipeline(knowledge_flow)
mmlu_pipe = Pipeline(mmlu_flow)
sdg = SDG([mmlu_pipe, knowledge_pipe])
elif profile == "simple":
knowledge_flow = SimpleKnowledgeFlow(client, model_name).get_flow()
# TODO -- llama-cpp doesn't support batching, we need to get a hint from the CLI
# about whether we can turn this on (whether vllm is used or not)
for i, _ in enumerate(knowledge_flow):
if (
"block_config" in knowledge_flow[i]
and "batch_kwargs" in knowledge_flow[i]["block_config"]
):
knowledge_flow[i]["block_config"]["batch_kwargs"]["batched"] = False
logger.debug(knowledge_flow[i])
else:
logger.debug(
"No batch_kwargs in knowledge_flow: %s" % knowledge_flow[i]
)
knowledge_flow[i]["block_config"]["logger"] = logger
sdg = SDG([Pipeline(knowledge_flow)])
else:
raise SystemExit(f"Error: profile ({profile}) is not supported.")

if console_output:
logger.info(
Expand All @@ -179,6 +230,7 @@ def generate_data(
continue

samples = [{}]
# pylint: disable=consider-using-enumerate
for i in range(len(leaf_node)):
samples[-1].setdefault("task_description", leaf_node[i]["task_description"])
samples[-1].setdefault("document", leaf_node[i]["document"])
Expand Down Expand Up @@ -209,17 +261,22 @@ def generate_data(
]

# TODO this is broken, just trying to get initial integration to run
# pylint: disable=consider-using-enumerate
for i in range(len(samples)):
samples[i]["document"] = utils.chunk_document(
documents=samples[i]["document"],
server_ctx_size=server_ctx_size,
chunk_word_count=chunk_word_count,
)[0]

# TODO -- there is a parameter for how many samples to generate, but we ignore it so far

ds = Dataset.from_list(samples)
generated_data = sdg.generate(ds)
logger.info("Generated %d samples" % len(generated_data))
logger.debug("Generated data: %s" % generated_data)

_gen_train_data(logger, generated_data, os.path.join(output_dir, output_file))

generate_duration = time.time() - generate_start
logger.info(f"Generation took {generate_duration:.2f}s")
12 changes: 7 additions & 5 deletions src/instructlab/sdg/llmblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,21 @@ def __init__(
def _parse(self, generated_string) -> dict:
matches = {}
for start_tag, end_tag, output_col in zip(
self.block_config["start_tags"],
self.block_config["end_tags"],
self.block_config.get("start_tags", []),
self.block_config.get("end_tags", []),
self.output_cols,
):
if not start_tag and not end_tag:
matches[output_col] = (
matches[output_col] = [
generated_string.strip() if generated_string else None
)
]
else:
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
all_matches = re.findall(pattern, generated_string, re.DOTALL)
matches[output_col] = (
[match.strip() for match in all_matches] if all_matches else None
)

logger.debug("_parse() matches: {}".format(matches))
return matches

def _generate(self, samples, **gen_kwargs) -> list:
Expand All @@ -82,6 +82,7 @@ def generate(self, samples, **gen_kwargs) -> Dataset:
"""
num_samples = self.batch_params.get("num_samples", None)
batched = self.batch_params.get("batched", False)
logger.debug("Generating outputs for {} samples".format(len(samples)))

if (num_samples is not None) and ("num_samples" not in samples.column_names):
samples = samples.add_column("num_samples", [num_samples] * len(samples))
Expand All @@ -97,6 +98,7 @@ def generate(self, samples, **gen_kwargs) -> Dataset:
outputs = self._generate(samples, **gen_kwargs)
else:
outputs = [self._generate([sample], **gen_kwargs)[0] for sample in samples]
logger.debug("Generated outputs: {}".format(outputs))

new_data = []
for sample, output in zip(samples, outputs):
Expand Down
2 changes: 1 addition & 1 deletion src/instructlab/sdg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def chunk_document(documents: List, server_ctx_size, chunk_word_count) -> List[s
def get_model_family(forced, model_path):
forced = MODEL_FAMILY_MAPPINGS.get(forced, forced)
if forced and forced.lower() not in MODEL_FAMILIES:
raise Exception("Unknown model family: %s" % forced)
raise GenerateException("Unknown model family: %s" % forced)

# Try to guess the model family based on the model's filename
guess = re.match(r"^\w*", os.path.basename(model_path)).group(0).lower()
Expand Down

0 comments on commit 7edbc76

Please sign in to comment.