diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index 3f023a12..2c812361 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -25,7 +25,9 @@ SimpleFreeformSkillFlow, SimpleGroundedSkillFlow, SimpleKnowledgeFlow, + SynthGroundedSkillsFlow, SynthKnowledgeFlow, + SynthSkillsFlow, ) from instructlab.sdg.pipeline import Pipeline from instructlab.sdg.utils import chunking, models @@ -122,19 +124,21 @@ def _gen_test_data( outfile.write("\n") -def _sdg_init(profile, client, model_family, model_name, num_iters, batched): +def _sdg_init(pipeline, client, model_family, model_name, num_iters, batched): knowledge_flow_types = [] freeform_skill_flow_types = [] grounded_skill_flow_types = [] - if profile == "full": + if pipeline == "full": knowledge_flow_types.append(MMLUBenchFlow) knowledge_flow_types.append(SynthKnowledgeFlow) - elif profile == "simple": + freeform_skill_flow_types.append(SynthSkillsFlow) + grounded_skill_flow_types.append(SynthGroundedSkillsFlow) + elif pipeline == "simple": knowledge_flow_types.append(SimpleKnowledgeFlow) freeform_skill_flow_types.append(SimpleFreeformSkillFlow) grounded_skill_flow_types.append(SimpleGroundedSkillFlow) else: - raise utils.GenerateException(f"Error: profile ({profile}) is not supported.") + raise utils.GenerateException(f"Error: pipeline ({pipeline}) is not supported.") sdg_knowledge = SDG( [ @@ -204,8 +208,8 @@ def generate_data( tls_client_cert: Optional[str] = None, tls_client_key: Optional[str] = None, tls_client_passwd: Optional[str] = None, - # TODO need to update the CLI to specify which profile to use (simple or full at the moment) - profile: Optional[str] = "simple", + # TODO need to update the CLI to specify which pipeline to use (simple or full at the moment) + pipeline: Optional[str] = "simple", ): generate_start = time.time() @@ -251,7 +255,12 @@ def generate_data( batched = False sdg_knowledge, sdg_freeform_skill, sdg_grounded_skill = _sdg_init( - profile, client, model_family, model_name, num_instructions_to_generate, batched + pipeline, + client, + model_family, + model_name, + num_instructions_to_generate, + batched, ) if console_output: