Add support for auxiliary dataset generation #204

Merged 1 commit on Jul 29, 2024
17 changes: 17 additions & 0 deletions src/instructlab/sdg/configs/knowledge/spellcheck.yaml
@@ -0,0 +1,17 @@
system: You are an AI assistant that is an expert at fixing spelling errors in documents.

introduction: |
  Give me a copy of the below document with all spelling errors corrected.

principles: |
  Do not add any new information.
  Do not leave out any information.

examples: ""

generation: |
  Document:
  {document}

start_tags: [""]
end_tags: [""]
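For orientation, here is a rough sketch (hypothetical, not taken from the SDG code base) of how these config fields fit together: the LLMBlock that references this file builds its prompt from the introduction, principles, examples and generation sections, filling {document} from the sample, while system is sent as the system message. The render_spellcheck_prompt helper below is an illustration only; the real templating may differ.

    # Hypothetical illustration of assembling the spellcheck.yaml fields into a prompt.
    import yaml

    def render_spellcheck_prompt(config_path: str, document: str) -> str:
        with open(config_path, encoding="utf-8") as f:
            cfg = yaml.safe_load(f)
        # Fill the {document} placeholder in the generation section
        generation = cfg["generation"].format(document=document)
        parts = [cfg["introduction"], cfg["principles"], cfg["examples"], generation]
        return "\n".join(p for p in parts if p)

    prompt = render_spellcheck_prompt(
        "src/instructlab/sdg/configs/knowledge/spellcheck.yaml",
        "Teh quick brown fox jumps over teh lazy dog.",
    )
    # cfg["system"] would accompany this prompt as the system message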
104 changes: 96 additions & 8 deletions src/instructlab/sdg/datamixing.py
@@ -1,5 +1,5 @@
# Standard
from typing import Optional
from typing import Dict, List, Optional
import json
import logging
import os.path
@@ -12,6 +12,7 @@

# First Party
from instructlab.sdg.utils import GenerateException, pandas
from instructlab.sdg.utils.pandas import dataset_from_pandas_dataframe

ALLOWED_COLS = ["id", "messages", "metadata"]
logger = logging.getLogger(__name__)
@@ -374,7 +375,68 @@ def _conv_pretrain(rec):
    return rec


def _create_phase10_ds(generated_dataset: Dataset):
def _create_auxiliary_dataset(
    generated_dataset: Dataset, auxiliary_inst: Optional[Dict[str, List[str]]]
):
    # Samples that went through the auxiliary generation pipeline will
    # have a dataset_type column created by that pipeline. If that's
    # not present, then we may be running in a pipeline without any
    # auxiliary dataset generation enabled.
    if "dataset_type" not in generated_dataset.column_names:
        return None
    # If we didn't find any auxiliary instructions to load, then
    # that's also another sign that we're not running with any
    # auxiliary datasets enabled.
    if auxiliary_inst is None:
        return None
    # This "base_document" dataset_type is set in the knowledge
    # pipeline config, and represents samples that do not have the
    # auxiliary generated document attached, so we filter those out.
    auxiliary_ds = generated_dataset.filter(
        lambda x: x["dataset_type"] != "base_document"
    )
    unique_document_auxiliary = auxiliary_ds.to_pandas().drop_duplicates(
        subset=["document"]
    )
    unique_document_auxiliary = dataset_from_pandas_dataframe(unique_document_auxiliary)
    unique_document_auxiliary = unique_document_auxiliary.select_columns(
        [
            "raw_document",
            "document_outline",
            "domain",
            "dataset_type",
            "document",
        ]
    )
    unique_document_auxiliary = unique_document_auxiliary.rename_columns(
        {"raw_document": "context", "document": "response"}
    )

    def __create_auxiliary_ds(rec):
        instruction = random.choice(auxiliary_inst[rec["dataset_type"]])
        messages = [
            {"role": "user", "content": f"{rec['context']}\n\n{instruction}"},
            {"role": "assistant", "content": rec["response"]},
        ]
        metadata = json.dumps(
            {
                "dataset_type": rec["dataset_type"],
                "raw_document": rec["context"],
                "dataset": f"document_{rec['dataset_type']}",
                "domain": rec["domain"],
            }
        )
        return {"messages": messages, "metadata": metadata, "id": str(uuid.uuid4())}

    unique_document_auxiliary = unique_document_auxiliary.map(
        __create_auxiliary_ds, remove_columns=unique_document_auxiliary.column_names
    )
    return unique_document_auxiliary
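To make the data flow concrete, here is a small usage sketch (illustrative only; the document text and instruction values are made up) showing the input shape this function expects and the messages/metadata/id records it produces:

    # Illustrative only: a two-row toy dataset exercising _create_auxiliary_dataset.
    from datasets import Dataset

    from instructlab.sdg.datamixing import _create_auxiliary_dataset

    generated = Dataset.from_dict(
        {
            "raw_document": ["Teh quick brown fox.", "Teh quick brown fox."],
            "document_outline": ["Animals", "Animals"],
            "domain": ["wildlife", "wildlife"],
            "dataset_type": ["base_document", "spellcheck"],
            "document": ["Teh quick brown fox.", "The quick brown fox."],
        }
    )
    auxiliary_inst = {"spellcheck": ["Correct any spelling errors in the document."]}

    aux_ds = _create_auxiliary_dataset(generated, auxiliary_inst)
    # The base_document row is filtered out, duplicates on "document" are dropped,
    # and each remaining row becomes a user/assistant message pair plus metadata.
    print(aux_ds[0]["messages"][1]["content"])  # -> "The quick brown fox."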


def _create_phase10_ds(
    generated_dataset: Dataset, auxiliary_inst: Optional[Dict[str, List[str]]]
):
    """
    Create a dataset for Phase 1.0 of downstream training.

@@ -387,10 +449,17 @@ def _create_phase10_ds(generated_dataset: Dataset):
    )
    knowledge_ds = _add_extra_contexts_to_samples(knowledge_ds, p=0.4)

    return knowledge_ds
    auxiliary_dataset = _create_auxiliary_dataset(generated_dataset, auxiliary_inst)
    if auxiliary_dataset is not None:
        phase10 = concatenate_datasets([knowledge_ds, auxiliary_dataset])
    else:
        phase10 = knowledge_ds
    return phase10


def _create_phase07_ds(generated_dataset: Dataset):
def _create_phase07_ds(
    generated_dataset: Dataset, auxiliary_inst: Optional[Dict[str, List[str]]]
):
    """
    Create a dataset for Phase 0.7 of downstream training.

@@ -404,7 +473,13 @@ def _create_phase07_ds(generated_dataset: Dataset):
    )
    knowledge_ds = knowledge_ds.map(_conv_pretrain)

    return knowledge_ds
    auxiliary_dataset = _create_auxiliary_dataset(generated_dataset, auxiliary_inst)
    if auxiliary_dataset is not None:
        auxiliary_dataset = auxiliary_dataset.map(_conv_pretrain)
        phase07 = concatenate_datasets([knowledge_ds, auxiliary_dataset])
    else:
        phase07 = knowledge_ds
    return phase07


def _convert_to_leaf_node_messages(sample: dict, sys_prompt: str):
@@ -440,12 +515,21 @@ class DataMixer:
    # once.
    NUM_SYNTH_SKILLS = 30

    def __init__(self, data_dirs, output_dir, date_suffix, sys_prompt, num_procs):
    def __init__(
        self,
        data_dirs,
        output_dir,
        date_suffix,
        sys_prompt,
        num_procs,
        auxiliary_inst=None,
    ):
        self.data_dirs = data_dirs
        self.output_dir = output_dir
        self.sys_prompt = sys_prompt
        self.date_suffix = date_suffix
        self.num_procs = num_procs
        self.auxiliary_inst = auxiliary_inst

        self.knowledge_recipe = self._load_default_recipe("knowledge.yaml")
        self.skills_recipe = self._load_default_recipe("skills.yaml")
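A hypothetical construction of the mixer with the new auxiliary_inst argument wired in (all argument values below are placeholders, not taken from the PR):

    # Placeholder values for illustration only.
    from instructlab.sdg.datamixing import DataMixer

    mixer = DataMixer(
        data_dirs=["/usr/share/instructlab/sdg"],
        output_dir="/tmp/sdg-output",
        date_suffix="2024-07-29T00_00_00",
        sys_prompt="You are a helpful assistant.",
        num_procs=4,
        auxiliary_inst={
            "spellcheck": ["Correct any spelling errors in the document."]
        },
    )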
@@ -482,7 +566,9 @@ def _gen_leaf_node_data(

    def collect(self, leaf_node_path, new_generated_data, is_knowledge):
        if is_knowledge:
            knowledge_phase_data = _create_phase07_ds(new_generated_data)
            knowledge_phase_data = _create_phase07_ds(
                new_generated_data, self.auxiliary_inst
            )
            output_file_leaf_knowledge = (
                f"node_datasets_{self.date_suffix}/{leaf_node_path}_p07.jsonl"
            )
@@ -492,7 +578,9 @@ def collect(self, leaf_node_path, new_generated_data, is_knowledge):
                output_file_leaf_knowledge,
            )

            skills_phase_data = _create_phase10_ds(new_generated_data)
            skills_phase_data = _create_phase10_ds(
                new_generated_data, self.auxiliary_inst
            )
            output_file_leaf_skills = (
                f"node_datasets_{self.date_suffix}/{leaf_node_path}_p10.jsonl"
            )
8 changes: 6 additions & 2 deletions src/instructlab/sdg/generate_data.py
@@ -247,7 +247,7 @@ def load_pipeline(yaml_basename):
    )


def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst):
    pd = platformdirs.PlatformDirs(
        appname=os.path.join("instructlab", "sdg"), multipath=True
    )
@@ -258,6 +258,7 @@ def _mixer_init(ctx, output_dir, date_suffix):
        date_suffix,
        _SYS_PROMPT,
        ctx.dataset_num_procs,
        knowledge_auxiliary_inst,
    )


@@ -367,7 +368,10 @@ def generate_data(
    mmlu_ctx = dataclasses.replace(ctx, checkpoint_dir=None)
    mmlu_bench_pipe = mmlubench_pipe_init(mmlu_ctx)

    mixer = _mixer_init(ctx, output_dir, date_suffix)
    # FIXME: remove SDG https://github.com/instructlab/sdg/pull/64
    mixer = _mixer_init(
        ctx, output_dir, date_suffix, sdg_knowledge.pipelines[0].auxiliary_inst
    )

    if console_output:
        logger.info(
13 changes: 10 additions & 3 deletions src/instructlab/sdg/pipeline.py
@@ -3,7 +3,7 @@
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from importlib import resources
from typing import Iterable, Optional
from typing import Dict, Iterable, List, Optional
import logging
import math
import os.path
@@ -109,6 +109,7 @@ def __init__(
        ctx: PipelineContext,
        config_path: str,
        chained_blocks: list[dict],
        auxiliary_inst: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """
        Initialize the Pipeline class with a configuration dictionary.
@@ -120,12 +121,14 @@
        self.config_path = config_path
        # pipeline config is the run configuration that consists of the pipeline steps
        self.chained_blocks = chained_blocks
        # datamixing instructions for auxiliary data generated by this pipeline
        self.auxiliary_inst = auxiliary_inst

    @classmethod
    def from_file(cls, ctx, pipeline_yaml):
        if not os.path.isabs(pipeline_yaml):
            pipeline_yaml = os.path.join(resources.files(__package__), pipeline_yaml)
        return cls(ctx, pipeline_yaml, _parse_pipeline_config_file(pipeline_yaml))
        return cls(ctx, pipeline_yaml, *_parse_pipeline_config_file(pipeline_yaml))

    def generate(self, dataset) -> Dataset:
        """
@@ -296,7 +299,11 @@ def _parse_pipeline_config_file(pipeline_yaml):
            "The pipeline config file contains no 'blocks' section"
        )

    return content["blocks"]
    auxiliary_inst = None
    if "datamixing" in content and "auxiliary_instructions" in content["datamixing"]:
        auxiliary_inst = content["datamixing"]["auxiliary_instructions"]

    return content["blocks"], auxiliary_inst
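In other words, the parser now returns a (blocks, auxiliary_inst) tuple, which Pipeline.from_file unpacks with the * operator above. A quick sketch of how the new return value is consumed (the YAML path is a placeholder and a PipelineContext named ctx is assumed to exist):

    # Sketch only: consuming the (blocks, auxiliary_inst) return value.
    from instructlab.sdg.pipeline import Pipeline, _parse_pipeline_config_file

    blocks, auxiliary_inst = _parse_pipeline_config_file("pipelines/full/knowledge.yaml")
    # blocks         -> the same list of block configs as before
    # auxiliary_inst -> {"spellcheck": ["Correct any spelling errors ..."]} or None
    pipeline = Pipeline(ctx, "pipelines/full/knowledge.yaml", blocks, auxiliary_inst)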


# This is part of the public API.
37 changes: 37 additions & 0 deletions src/instructlab/sdg/pipelines/full/knowledge.yaml
@@ -1,5 +1,36 @@
version: "1.0"
blocks:
  - name: duplicate_document_col
    type: DuplicateColumnsBlock
    config:
      columns_map:
        document: base_document

  - name: gen_spellcheck
    type: LLMBlock
    config:
      config_path: ../../configs/knowledge/spellcheck.yaml
      output_cols:
        - spellcheck
      gen_kwargs:
        max_tokens: 2048

  - name: flatten_auxiliary_columns
    type: FlattenColumnsBlock
    config:
      var_cols:
        - spellcheck
        - base_document
      value_name: corrected_document
      var_name: dataset_type

  - name: rename_to_document_column
    type: RenameColumnsBlock
    config:
      columns_map:
        document: raw_document
        corrected_document: document
Review discussion on rename_to_document_column:

Contributor: This makes me nervous - columns_map is a dict ... what's guaranteeing the ordering here?

Contributor: In the datasets library:

        def rename(columns):
            return [column_mapping[col] if col in column_mapping else col for col in columns]

i.e. it applies the mapping in the order the columns appear in the dataset. Ok, at least that's deterministic.

Contributor (author): Agreed, and good catch - that seems quite fragile. This config should probably be refactored into an array, or split into two steps, to remove any doubt about the order in which the renaming is applied.

Contributor (author): I think this could be done after merging this larger PR, since the order happens to be deterministic, especially because it may warrant a second look at the API exposed by RenameColumnsBlock.

  - name: gen_knowledge
    type: LLMBlock
    config:
@@ -73,3 +104,9 @@ blocks:
        - explanation
        - rating
        - __index_level_0__

datamixing:
  auxiliary_instructions:
    spellcheck:
      - Correct any spelling errors in the document and output the corrected version.
      - Rewrite the document to remove any spelling errors.
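To see what the four auxiliary blocks at the top of this pipeline do to the columns, here is a rough pandas sketch (an analogy under stated assumptions, not the blocks' actual implementation; the sample values and the extra icl_query_1 column are made up). FlattenColumnsBlock behaves like a melt over the two document variants, and the final rename swaps the original text into raw_document while each variant becomes the new document column:

    # Analogy only: mimicking duplicate -> spellcheck -> flatten -> rename with pandas.
    import pandas as pd

    df = pd.DataFrame({"document": ["Teh fox."], "icl_query_1": ["What animal is it?"]})

    # duplicate_document_col: keep an untouched copy of the source document
    df["base_document"] = df["document"]

    # gen_spellcheck: the LLM writes a corrected copy into a new column
    df["spellcheck"] = ["The fox."]

    # flatten_auxiliary_columns: one row per variant, tagged with its dataset_type
    df = df.melt(
        id_vars=["document", "icl_query_1"],
        value_vars=["spellcheck", "base_document"],
        value_name="corrected_document",
        var_name="dataset_type",
    )

    # rename_to_document_column: original text -> raw_document, variant -> document
    df = df.rename(columns={"document": "raw_document", "corrected_document": "document"})
    print(df[["dataset_type", "raw_document", "document"]])

Note that pandas applies the rename mapping atomically; the ordering question raised in the review discussion above concerns how RenameColumnsBlock applies the same mapping, not this analogy.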
17 changes: 17 additions & 0 deletions src/instructlab/sdg/pipelines/schema/v1.json
@@ -364,6 +364,23 @@
                    }
                }
            }
        },
        "datamixing": {
            "type": "object",
            "additionalProperties": false,
            "properties": {
                "auxiliary_instructions": {
                    "type": "object",
                    "patternProperties": {
                        ".*": {
                            "type": "array",
                            "items": {
                                "type": "string"
                            }
                        }
                    }
                }
            }
        }
    }
}
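As a quick sanity check, the new datamixing fragment can be exercised on its own with the jsonschema package (illustrative only; it validates a hand-written section against the fragment added above rather than the full v1.json):

    # Illustrative: validate a datamixing section against the schema fragment above.
    import jsonschema

    datamixing_schema = {
        "type": "object",
        "additionalProperties": False,
        "properties": {
            "auxiliary_instructions": {
                "type": "object",
                "patternProperties": {
                    ".*": {"type": "array", "items": {"type": "string"}}
                },
            }
        },
    }

    section = {
        "auxiliary_instructions": {
            "spellcheck": ["Correct any spelling errors in the document."]
        }
    }
    jsonschema.validate(instance=section, schema=datamixing_schema)  # passes

    bad = {"auxiliary_instructions": {"spellcheck": "not a list"}}
    # jsonschema.validate(instance=bad, schema=datamixing_schema)  # would raise ValidationError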
6 changes: 6 additions & 0 deletions tests/test_default_pipeline_configs.py
@@ -12,6 +12,9 @@
from instructlab.sdg.pipeline import Pipeline, PipelineContext
from instructlab.sdg.utilblocks import (
    CombineColumnsBlock,
    DuplicateColumnsBlock,
    FlattenColumnsBlock,
    RenameColumnsBlock,
    SamplePopulatorBlock,
    SelectorBlock,
)
@@ -23,8 +26,11 @@ def _noop_generate(self, samples):

@patch.object(CombineColumnsBlock, "generate", _noop_generate)
@patch.object(ConditionalLLMBlock, "generate", _noop_generate)
@patch.object(DuplicateColumnsBlock, "generate", _noop_generate)
@patch.object(FilterByValueBlock, "generate", _noop_generate)
@patch.object(FlattenColumnsBlock, "generate", _noop_generate)
@patch.object(LLMBlock, "generate", _noop_generate)
@patch.object(RenameColumnsBlock, "generate", _noop_generate)
@patch.object(SamplePopulatorBlock, "generate", _noop_generate)
@patch.object(SelectorBlock, "generate", _noop_generate)
@patch("instructlab.sdg.llmblock.server_supports_batched", lambda c, m: True)