Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support converting messages datasets into multiple pre-training formats #341

Merged
merged 1 commit into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions src/instructlab/sdg/datamixing.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,17 +362,24 @@ def __pick_documents(rec, p):
return ds


def _conv_pretrain(rec):
def _conv_pretrain(rec, use_legacy_pretraining_format: bool):
"""
Convert a messages dataset that contains only user/assistant entries per
message (and in that order) to a pretraining message used downstream by
the training pipeline. `_generate_knowledge_qa_dataset` creates the type
of dataset expected here.
"""
if use_legacy_pretraining_format:
user = "<|user|>"
assistant = "<|assistant|>"
else:
user = "<|start_of_role|>user<|end_of_role|>"
assistant = "<|start_of_role|>assistant<|end_of_role|>"

rec["messages"] = [
{
"role": "pretraining",
"content": f"<|user|>\n{rec['messages'][0]['content']}\n<|assistant|>\n{rec['messages'][1]['content']}",
"content": f"{user}\n{rec['messages'][0]['content']}\n{assistant}\n{rec['messages'][1]['content']}",
}
]
return rec
Expand Down Expand Up @@ -461,7 +468,9 @@ def _create_phase10_ds(


def _create_phase07_ds(
generated_dataset: Dataset, auxiliary_inst: Optional[Dict[str, List[str]]]
generated_dataset: Dataset,
auxiliary_inst: Optional[Dict[str, List[str]]],
use_legacy_pretraining_format: bool,
):
"""
Create a dataset for Phase 0.7 of downstream training.
Expand All @@ -474,11 +483,15 @@ def _create_phase07_ds(
knowledge_ds = _generate_knowledge_qa_dataset(
generated_dataset, keep_context_separate=False
)
knowledge_ds = knowledge_ds.map(_conv_pretrain)
knowledge_ds = knowledge_ds.map(
lambda rec: _conv_pretrain(rec, use_legacy_pretraining_format)
)

auxiliary_dataset = _create_auxiliary_dataset(generated_dataset, auxiliary_inst)
if auxiliary_dataset is not None:
auxiliary_dataset = auxiliary_dataset.map(_conv_pretrain)
auxiliary_dataset = auxiliary_dataset.map(
lambda rec: _conv_pretrain(rec, use_legacy_pretraining_format)
)
phase07 = concatenate_datasets([knowledge_ds, auxiliary_dataset])
else:
phase07 = knowledge_ds
Expand Down Expand Up @@ -567,10 +580,16 @@ def _gen_leaf_node_data(
leaf_node_data.to_json(output_file, orient="records", lines=True)
recipe.add_dataset(output_file_leaf_node, sampling_size)

def collect(self, leaf_node_path, new_generated_data, is_knowledge):
def collect(
self,
leaf_node_path,
new_generated_data,
is_knowledge,
use_legacy_pretraining_format,
):
if is_knowledge:
knowledge_phase_data = _create_phase07_ds(
new_generated_data, self.auxiliary_inst
new_generated_data, self.auxiliary_inst, use_legacy_pretraining_format
)
output_file_leaf_knowledge = (
f"node_datasets_{self.date_suffix}/{leaf_node_path}_p07.jsonl"
Expand Down
8 changes: 7 additions & 1 deletion src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def generate_data(
client: openai.OpenAI,
logger: logging.Logger = logger, # pylint: disable=redefined-outer-name
system_prompt: Optional[str] = None,
use_legacy_pretraining_format: Optional[bool] = True,
model_family: Optional[str] = None,
model_name: Optional[str] = None,
num_cpus: Optional[int] = None,
Expand Down Expand Up @@ -423,7 +424,12 @@ def generate_data(
date_suffix,
)

mixer.collect(leaf_node_path, new_generated_data, is_knowledge)
mixer.collect(
leaf_node_path,
new_generated_data,
is_knowledge,
use_legacy_pretraining_format,
)

if generated_data is None:
generated_data = []
Expand Down
Loading