Skip to content

Commit

Permalink
Fix prompt file paths for an installed library (#67)
Browse files Browse the repository at this point in the history
* default_flows: Move sdg_base into base Flow class

Each subclass was calculating the root directory of the python package
for finding prompt templates that are embedded in the package. Move
this to the base `Flow` class so we're only doing it in one place.

Signed-off-by: Russell Bryant <[email protected]>

* Fix remaining prompt file paths

Special handling is required to find these prompt files from an
installed version of the instructlab-sdg package. This updates the
rest of this file to make use of `self.sdg_base` from the base `Flow`
class which will work from an installed library.

Signed-off-by: Russell Bryant <[email protected]>

---------

Signed-off-by: Russell Bryant <[email protected]>
  • Loading branch information
russellb authored Jul 3, 2024
1 parent 941122c commit afbea4c
Showing 1 changed file with 47 additions and 22 deletions.
69 changes: 47 additions & 22 deletions src/instructlab/sdg/default_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self, client, model_family, model_id, num_iters, batched=True) -> N
self.model_id = model_id
self.num_iters = num_iters
self.batched = batched
self.sdg_base = resources.files(__package__)

@abstractmethod
def get_flow(self) -> list:
Expand Down Expand Up @@ -76,9 +77,8 @@ def get_flow(self) -> list:
class SimpleKnowledgeFlow(_SimpleFlow):
def get_flow(self) -> list:
flow = super().get_flow()
sdg_base = resources.files(__package__)
flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join(
sdg_base, "configs/knowledge/simple_generate_qa.yaml"
self.sdg_base, "configs/knowledge/simple_generate_qa.yaml"
)
flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_knowledge"
flow[0]["block_config"]["block_name"] = "gen_knowledge"
Expand All @@ -88,9 +88,8 @@ def get_flow(self) -> list:
class SimpleFreeformSkillFlow(_SimpleFlow):
def get_flow(self) -> list:
flow = super().get_flow()
sdg_base = resources.files(__package__)
flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join(
sdg_base, "configs/skills/simple_generate_qa_freeform.yaml"
self.sdg_base, "configs/skills/simple_generate_qa_freeform.yaml"
)
flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_skill_freeform"
flow[0]["block_config"]["block_name"] = "gen_skill_freeform"
Expand All @@ -100,9 +99,8 @@ def get_flow(self) -> list:
class SimpleGroundedSkillFlow(_SimpleFlow):
def get_flow(self) -> list:
flow = super().get_flow()
sdg_base = resources.files(__package__)
flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join(
sdg_base, "configs/skills/simple_generate_qa_grounded.yaml"
self.sdg_base, "configs/skills/simple_generate_qa_grounded.yaml"
)
flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_skill_grounded"
flow[0]["block_config"]["block_name"] = "gen_skill_grounded"
Expand All @@ -111,14 +109,14 @@ def get_flow(self) -> list:

class MMLUBenchFlow(Flow):
def get_flow(self) -> list:
sdg_base = resources.files(__package__)
self.sdg_base = resources.files(__package__)
return [
{
"block_type": LLMBlock,
"block_config": {
"block_name": "gen_mmlu_knowledge",
"config_path": os.path.join(
sdg_base, "configs/knowledge/mcq_generation.yaml"
self.sdg_base, "configs/knowledge/mcq_generation.yaml"
),
"client": self.client,
"model_id": self.model_id,
Expand All @@ -140,14 +138,14 @@ def get_flow(self) -> list:

class SynthKnowledgeFlow(Flow):
def get_flow(self) -> list:
sdg_base = resources.files(__package__)
return [
{
"block_type": LLMBlock,
"block_config": {
"block_name": "gen_knowledge",
"config_path": os.path.join(
sdg_base, "configs/knowledge/generate_questions_responses.yaml"
self.sdg_base,
"configs/knowledge/generate_questions_responses.yaml",
),
"client": self.client,
"model_id": self.model_id,
Expand All @@ -173,7 +171,7 @@ def get_flow(self) -> list:
"block_config": {
"block_name": "eval_faithfulness_qa_pair",
"config_path": os.path.join(
sdg_base, "configs/knowledge/evaluate_faithfulness.yaml"
self.sdg_base, "configs/knowledge/evaluate_faithfulness.yaml"
),
"client": self.client,
"model_id": self.model_id,
Expand Down Expand Up @@ -206,7 +204,7 @@ def get_flow(self) -> list:
"block_config": {
"block_name": "eval_relevancy_qa_pair",
"config_path": os.path.join(
sdg_base, "configs/knowledge/evaluate_relevancy.yaml"
self.sdg_base, "configs/knowledge/evaluate_relevancy.yaml"
),
"client": self.client,
"model_id": self.model_id,
Expand Down Expand Up @@ -240,7 +238,7 @@ def get_flow(self) -> list:
"block_config": {
"block_name": "eval_verify_question",
"config_path": os.path.join(
sdg_base, "configs/knowledge/evaluate_question.yaml"
self.sdg_base, "configs/knowledge/evaluate_question.yaml"
),
"client": self.client,
"model_id": self.model_id,
Expand Down Expand Up @@ -279,7 +277,10 @@ def get_flow(self) -> list:
"block_type": LLMBlock,
"block_config": {
"block_name": "gen_questions",
"config_path": "src/instructlab/sdg/configs/skills/freeform_questions.yaml",
"config_path": os.path.join(
self.sdg_base,
"configs/skills/freeform_questions.yaml",
),
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
Expand All @@ -296,7 +297,10 @@ def get_flow(self) -> list:
"block_type": LLMBlock,
"block_config": {
"block_name": "eval_questions",
"config_path": "src/instructlab/sdg/configs/skills/evaluate_freeform_questions.yaml",
"config_path": os.path.join(
self.sdg_base,
"configs/skills/evaluate_freeform_questions.yaml",
),
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
Expand Down Expand Up @@ -325,7 +329,10 @@ def get_flow(self) -> list:
"block_type": LLMBlock,
"block_config": {
"block_name": "gen_responses",
"config_path": "src/instructlab/sdg/configs/skills/freeform_responses.yaml",
"config_path": os.path.join(
self.sdg_base,
"configs/skills/freeform_responses.yaml",
),
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
Expand All @@ -340,7 +347,10 @@ def get_flow(self) -> list:
"block_type": LLMBlock,
"block_config": {
"block_name": "evaluate_qa_pair",
"config_path": "src/instructlab/sdg/configs/skills/evaluate_freeform_pair.yaml",
"config_path": os.path.join(
self.sdg_base,
"configs/skills/evaluate_freeform_pair.yaml",
),
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
Expand Down Expand Up @@ -379,7 +389,10 @@ def get_flow(self) -> list:
"block_type": LLMBlock,
"block_kwargs": {
"block_name": "gen_contexts",
"config_path": "src/instructlab/sdg/configs/skills/contexts.yaml",
"config_path": os.path.join(
self.sdg_base,
"configs/skills/contexts.yaml",
),
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
Expand All @@ -399,7 +412,10 @@ def get_flow(self) -> list:
"block_type": LLMBlock,
"block_config": {
"block_name": "gen_grounded_questions",
"config_path": "src/instructlab/sdg/configs/skills/grounded_questions.yaml",
"config_path": os.path.join(
self.sdg_base,
"configs/skills/grounded_questions.yaml",
),
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
Expand All @@ -415,7 +431,10 @@ def get_flow(self) -> list:
"block_type": LLMBlock,
"block_config": {
"block_name": "eval_grounded_questions",
"config_path": "src/instructlab/sdg/configs/skills/evaluate_grounded_questions.yaml",
"config_path": os.path.join(
self.sdg_base,
"configs/skills/evaluate_grounded_questions.yaml",
),
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
Expand Down Expand Up @@ -445,7 +464,10 @@ def get_flow(self) -> list:
"block_type": LLMBlock,
"block_config": {
"block_name": "gen_grounded_responses",
"config_path": "src/instructlab/sdg/configs/skills/grounded_responses.yaml",
"config_path": os.path.join(
self.sdg_base,
"configs/skills/grounded_responses.yaml",
),
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
Expand All @@ -460,7 +482,10 @@ def get_flow(self) -> list:
"block_type": LLMBlock,
"block_config": {
"block_name": "evaluate_grounded_qa_pair",
"config_path": "src/instructlab/sdg/configs/skills/evaluate_grounded_pair.yaml",
"config_path": os.path.join(
self.sdg_base,
"configs/skills/evaluate_grounded_pair.yaml",
),
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
Expand Down

0 comments on commit afbea4c

Please sign in to comment.