From afbea4c7c2830febf46802ce35d586f87b5dc018 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 3 Jul 2024 10:44:59 -0400 Subject: [PATCH] Fix prompt file paths for an installed library (#67) * default_flows: Move sdg_base into base Flow class Each subclass was calculating the root directory of the python package for finding prompt templates that are embedded in the package. Move this to the base `Flow` class so we're only doing it in one place. Signed-off-by: Russell Bryant * Fix remaining prompt file paths Special handling is required to find these prompt files from an installed version of the instructlab-sdg package. This updates the rest of this file to make use of `self.sdg_base` from the base `Flow` class which will work from an installed library. Signed-off-by: Russell Bryant --------- Signed-off-by: Russell Bryant --- src/instructlab/sdg/default_flows.py | 69 +++++++++++++++++++--------- 1 file changed, 47 insertions(+), 22 deletions(-) diff --git a/src/instructlab/sdg/default_flows.py b/src/instructlab/sdg/default_flows.py index 31edd3d6..f83ed60b 100644 --- a/src/instructlab/sdg/default_flows.py +++ b/src/instructlab/sdg/default_flows.py @@ -36,6 +36,7 @@ def __init__(self, client, model_family, model_id, num_iters, batched=True) -> N self.model_id = model_id self.num_iters = num_iters self.batched = batched + self.sdg_base = resources.files(__package__) @abstractmethod def get_flow(self) -> list: @@ -76,9 +77,8 @@ def get_flow(self) -> list: class SimpleKnowledgeFlow(_SimpleFlow): def get_flow(self) -> list: flow = super().get_flow() - sdg_base = resources.files(__package__) flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join( - sdg_base, "configs/knowledge/simple_generate_qa.yaml" + self.sdg_base, "configs/knowledge/simple_generate_qa.yaml" ) flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_knowledge" flow[0]["block_config"]["block_name"] = "gen_knowledge" @@ -88,9 +88,8 @@ def get_flow(self) -> list: class SimpleFreeformSkillFlow(_SimpleFlow): def get_flow(self) -> list: flow = super().get_flow() - sdg_base = resources.files(__package__) flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join( - sdg_base, "configs/skills/simple_generate_qa_freeform.yaml" + self.sdg_base, "configs/skills/simple_generate_qa_freeform.yaml" ) flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_skill_freeform" flow[0]["block_config"]["block_name"] = "gen_skill_freeform" @@ -100,9 +99,8 @@ def get_flow(self) -> list: class SimpleGroundedSkillFlow(_SimpleFlow): def get_flow(self) -> list: flow = super().get_flow() - sdg_base = resources.files(__package__) flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join( - sdg_base, "configs/skills/simple_generate_qa_grounded.yaml" + self.sdg_base, "configs/skills/simple_generate_qa_grounded.yaml" ) flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_skill_grounded" flow[0]["block_config"]["block_name"] = "gen_skill_grounded" @@ -111,14 +109,14 @@ def get_flow(self) -> list: class MMLUBenchFlow(Flow): def get_flow(self) -> list: - sdg_base = resources.files(__package__) + self.sdg_base = resources.files(__package__) return [ { "block_type": LLMBlock, "block_config": { "block_name": "gen_mmlu_knowledge", "config_path": os.path.join( - sdg_base, "configs/knowledge/mcq_generation.yaml" + self.sdg_base, "configs/knowledge/mcq_generation.yaml" ), "client": self.client, "model_id": self.model_id, @@ -140,14 +138,14 @@ def get_flow(self) -> list: class SynthKnowledgeFlow(Flow): def get_flow(self) -> list: - sdg_base = resources.files(__package__) return [ { "block_type": LLMBlock, "block_config": { "block_name": "gen_knowledge", "config_path": os.path.join( - sdg_base, "configs/knowledge/generate_questions_responses.yaml" + self.sdg_base, + "configs/knowledge/generate_questions_responses.yaml", ), "client": self.client, "model_id": self.model_id, @@ -173,7 +171,7 @@ def get_flow(self) -> list: "block_config": { "block_name": "eval_faithfulness_qa_pair", "config_path": os.path.join( - sdg_base, "configs/knowledge/evaluate_faithfulness.yaml" + self.sdg_base, "configs/knowledge/evaluate_faithfulness.yaml" ), "client": self.client, "model_id": self.model_id, @@ -206,7 +204,7 @@ def get_flow(self) -> list: "block_config": { "block_name": "eval_relevancy_qa_pair", "config_path": os.path.join( - sdg_base, "configs/knowledge/evaluate_relevancy.yaml" + self.sdg_base, "configs/knowledge/evaluate_relevancy.yaml" ), "client": self.client, "model_id": self.model_id, @@ -240,7 +238,7 @@ def get_flow(self) -> list: "block_config": { "block_name": "eval_verify_question", "config_path": os.path.join( - sdg_base, "configs/knowledge/evaluate_question.yaml" + self.sdg_base, "configs/knowledge/evaluate_question.yaml" ), "client": self.client, "model_id": self.model_id, @@ -279,7 +277,10 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "gen_questions", - "config_path": "src/instructlab/sdg/configs/skills/freeform_questions.yaml", + "config_path": os.path.join( + self.sdg_base, + "configs/skills/freeform_questions.yaml", + ), "client": self.client, "model_id": self.model_id, "model_prompt": _get_model_prompt(self.model_family), @@ -296,7 +297,10 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "eval_questions", - "config_path": "src/instructlab/sdg/configs/skills/evaluate_freeform_questions.yaml", + "config_path": os.path.join( + self.sdg_base, + "configs/skills/evaluate_freeform_questions.yaml", + ), "client": self.client, "model_id": self.model_id, "model_prompt": _get_model_prompt(self.model_family), @@ -325,7 +329,10 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "gen_responses", - "config_path": "src/instructlab/sdg/configs/skills/freeform_responses.yaml", + "config_path": os.path.join( + self.sdg_base, + "configs/skills/freeform_responses.yaml", + ), "client": self.client, "model_id": self.model_id, "model_prompt": _get_model_prompt(self.model_family), @@ -340,7 +347,10 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "evaluate_qa_pair", - "config_path": "src/instructlab/sdg/configs/skills/evaluate_freeform_pair.yaml", + "config_path": os.path.join( + self.sdg_base, + "configs/skills/evaluate_freeform_pair.yaml", + ), "client": self.client, "model_id": self.model_id, "model_prompt": _get_model_prompt(self.model_family), @@ -379,7 +389,10 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_kwargs": { "block_name": "gen_contexts", - "config_path": "src/instructlab/sdg/configs/skills/contexts.yaml", + "config_path": os.path.join( + self.sdg_base, + "configs/skills/contexts.yaml", + ), "client": self.client, "model_id": self.model_id, "model_prompt": _get_model_prompt(self.model_family), @@ -399,7 +412,10 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "gen_grounded_questions", - "config_path": "src/instructlab/sdg/configs/skills/grounded_questions.yaml", + "config_path": os.path.join( + self.sdg_base, + "configs/skills/grounded_questions.yaml", + ), "client": self.client, "model_id": self.model_id, "model_prompt": _get_model_prompt(self.model_family), @@ -415,7 +431,10 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "eval_grounded_questions", - "config_path": "src/instructlab/sdg/configs/skills/evaluate_grounded_questions.yaml", + "config_path": os.path.join( + self.sdg_base, + "configs/skills/evaluate_grounded_questions.yaml", + ), "client": self.client, "model_id": self.model_id, "model_prompt": _get_model_prompt(self.model_family), @@ -445,7 +464,10 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "gen_grounded_responses", - "config_path": "src/instructlab/sdg/configs/skills/grounded_responses.yaml", + "config_path": os.path.join( + self.sdg_base, + "configs/skills/grounded_responses.yaml", + ), "client": self.client, "model_id": self.model_id, "model_prompt": _get_model_prompt(self.model_family), @@ -460,7 +482,10 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "evaluate_grounded_qa_pair", - "config_path": "src/instructlab/sdg/configs/skills/evaluate_grounded_pair.yaml", + "config_path": os.path.join( + self.sdg_base, + "configs/skills/evaluate_grounded_pair.yaml", + ), "client": self.client, "model_id": self.model_id, "model_prompt": _get_model_prompt(self.model_family),