diff --git a/scripts/test_freeform_skills.py b/scripts/test_freeform_skills.py
index a8612c09..058fd64f 100644
--- a/scripts/test_freeform_skills.py
+++ b/scripts/test_freeform_skills.py
@@ -5,7 +5,7 @@
 # First Party
 from src.instructlab.sdg import SDG
 from src.instructlab.sdg.default_flows import SynthSkillsFlow
-from src.instructlab.sdg.pipeline import Pipeline
+from src.instructlab.sdg.pipeline import Pipeline, PipelineContext

 # for vLLM endpoints, the api_key remains "EMPTY"
 openai_api_key = "EMPTY"
@@ -49,7 +49,9 @@

 ds = Dataset.from_list(samples)

-skills_flow = SynthSkillsFlow(client, "mixtral", teacher_model, 1).get_flow()
-skills_pipe = Pipeline(skills_flow)
+ctx = PipelineContext(client, "mixtral", teacher_model, 1)
+
+skills_flow = SynthSkillsFlow(ctx).get_flow()
+skills_pipe = Pipeline(ctx, skills_flow)

 sdg = SDG([skills_pipe])
diff --git a/scripts/test_grounded_skills.py b/scripts/test_grounded_skills.py
index 338edb6c..6d0bdc1b 100644
--- a/scripts/test_grounded_skills.py
+++ b/scripts/test_grounded_skills.py
@@ -5,7 +5,7 @@
 # First Party
 from src.instructlab.sdg import SDG
 from src.instructlab.sdg.default_flows import SynthGroundedSkillsFlow
-from src.instructlab.sdg.pipeline import Pipeline
+from src.instructlab.sdg.pipeline import Pipeline, PipelineContext

 # for vLLM endpoints, the api_key remains "EMPTY"
 openai_api_key = "EMPTY"
@@ -97,7 +97,9 @@

 ds = Dataset.from_list(samples)

-skills_flow = SynthGroundedSkillsFlow(client, "mixtral", teacher_model, 10).get_flow()
-skills_pipe = Pipeline(skills_flow)
+ctx = PipelineContext(client, "mixtral", teacher_model, 10)
+
+skills_flow = SynthGroundedSkillsFlow(ctx).get_flow()
+skills_pipe = Pipeline(ctx, skills_flow)

 sdg = SDG([skills_pipe])
diff --git a/scripts/test_knowledge.py b/scripts/test_knowledge.py
index aeedcf59..2b534903 100644
--- a/scripts/test_knowledge.py
+++ b/scripts/test_knowledge.py
@@ -8,7 +8,7 @@
 # First Party
 from src.instructlab.sdg import SDG
 from src.instructlab.sdg.default_flows import MMLUBenchFlow, SynthKnowledgeFlow
-from src.instructlab.sdg.pipeline import Pipeline
+from src.instructlab.sdg.pipeline import Pipeline, PipelineContext

 # Please don't add you vLLM endpoint key here
 openai_api_key = "EMPTY"
@@ -38,12 +38,13 @@

 ds = Dataset.from_list(samples)

-mmlu_flow = MMLUBenchFlow(client, "mixtral", teacher_model, 1).get_flow()
-knowledge_flow = SynthKnowledgeFlow(client, "mixtral", teacher_model, 1).get_flow()
-knowledge_pipe = Pipeline(knowledge_flow)
-mmlu_pipe = Pipeline(mmlu_flow)
+ctx = PipelineContext(client, "mixtral", teacher_model, 1)

-sdg = SDG([mmlu_pipe, knowledge_pipe])
+mmlu_flow = MMLUBenchFlow(ctx).get_flow()
+knowledge_flow = SynthKnowledgeFlow(ctx).get_flow()
+knowledge_pipe = Pipeline(ctx, mmlu_flow + knowledge_flow)
+
+sdg = SDG([knowledge_pipe])

 mmlubench_data = sdg.generate(ds)
 print(mmlubench_data)
diff --git a/src/instructlab/sdg/block.py b/src/instructlab/sdg/block.py
index 09433f55..a28136c4 100644
--- a/src/instructlab/sdg/block.py
+++ b/src/instructlab/sdg/block.py
@@ -3,6 +3,7 @@
 from abc import ABC
 from collections import ChainMap
 from typing import Any, Dict, Union
+import os.path

 # Third Party
 import yaml
@@ -14,7 +15,8 @@


 class Block(ABC):
-    def __init__(self, block_name: str) -> None:
+    def __init__(self, ctx, block_name: str) -> None:
+        self.ctx = ctx
         self.block_name = block_name

     @staticmethod
@@ -41,8 +43,13 @@ def _load_config(self, config_path: str) -> Union[Dict[str, Any], None]:
         """
         Load the configuration file for this block.

+        If the supplied configuration file is a relative path, it is assumed
+        to be part of this Python package.
+
         :param config_path: The path to the configuration file.
         :return: The loaded configuration.
         """
+        if not os.path.isabs(config_path):
+            config_path = os.path.join(self.ctx.sdg_base, config_path)
         with open(config_path, "r", encoding="utf-8") as config_file:
             return yaml.safe_load(config_file)
diff --git a/src/instructlab/sdg/default_flows.py b/src/instructlab/sdg/default_flows.py
index 818c4972..2839e212 100644
--- a/src/instructlab/sdg/default_flows.py
+++ b/src/instructlab/sdg/default_flows.py
@@ -1,42 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0

 # Standard
 from abc import ABC, abstractmethod
-from importlib import resources
 import operator
-import os

 # Local
 from .filterblock import FilterByValueBlock
 from .llmblock import LLMBlock
 from .utilblocks import CombineColumnsBlock

-MODEL_FAMILY_MIXTRAL = "mixtral"
-MODEL_FAMILY_MERLINITE = "merlinite"
-
-_MODEL_PROMPT_MIXTRAL = " [INST] {prompt} [/INST]"
-_MODEL_PROMPT_MERLINITE = "'<|system|>\nYou are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.\n<|user|>\n{prompt}\n<|assistant|>\n'"
-
-_MODEL_PROMPTS = {
-    MODEL_FAMILY_MIXTRAL: _MODEL_PROMPT_MIXTRAL,
-    MODEL_FAMILY_MERLINITE: _MODEL_PROMPT_MERLINITE,
-}
-
-
-def _get_model_prompt(model_family):
-    if model_family not in _MODEL_PROMPTS:
-        raise ValueError(f"Unknown model family: {model_family}")
-    return _MODEL_PROMPTS[model_family]
-

 class Flow(ABC):
-    def __init__(
-        self, client, model_family, model_id, num_instructions_to_generate
-    ) -> None:
-        self.client = client
-        self.model_family = model_family
-        self.model_id = model_id
-        self.num_instructions_to_generate = num_instructions_to_generate
-        self.sdg_base = resources.files(__package__)
+    def __init__(self, ctx) -> None:
+        self.ctx = ctx

     @abstractmethod
     def get_flow(self) -> list:
@@ -51,15 +26,12 @@ def get_flow(self) -> list:
             "block_config": {
                 "block_name": "",  # must be set by subclass
                 "config_path": "",  # must be set by subclass
-                "client": self.client,
-                "model_id": self.model_id,
-                "model_prompt": _get_model_prompt(self.model_family),
                 "output_cols": ["output"],
             },
             "gen_kwargs": {
                 "max_tokens": 2048,
                 "temperature": 0.7,
-                "n": self.num_instructions_to_generate,
+                "n": self.ctx.num_instructions_to_generate,
             },
             "drop_duplicates": ["output"],
         }
@@ -69,8 +41,8 @@ class SimpleKnowledgeFlow(_SimpleFlow):
     def get_flow(self) -> list:
         flow = super().get_flow()
-        flow[0]["block_config"]["config_path"] = os.path.join(
-            self.sdg_base, "configs/knowledge/simple_generate_qa.yaml"
+        flow[0]["block_config"]["config_path"] = (
+            "configs/knowledge/simple_generate_qa.yaml"
         )
         flow[0]["block_config"]["block_name"] = "gen_knowledge"
         return flow
@@ -79,19 +51,18 @@ def get_flow(self) -> list:

 class SimpleFreeformSkillFlow(_SimpleFlow):
     def get_flow(self) -> list:
         flow = super().get_flow()
-        flow[0]["block_config"]["config_path"] = os.path.join(
-            self.sdg_base, "configs/skills/simple_generate_qa_freeform.yaml"
+        flow[0]["block_config"]["config_path"] = (
+            "configs/skills/simple_generate_qa_freeform.yaml"
         )
         flow[0]["block_config"]["block_name"] = "gen_skill_freeform"
-        flow[0]["block_config"]["block_name"] = "gen_skill_freeform"
         return flow


 class SimpleGroundedSkillFlow(_SimpleFlow):
     def get_flow(self) -> list:
         flow = super().get_flow()
-        flow[0]["block_config"]["config_path"] = os.path.join(
-            self.sdg_base, "configs/skills/simple_generate_qa_grounded.yaml"
+        flow[0]["block_config"]["config_path"] = (
flow[0]["block_config"]["config_path"] = ( + "configs/skills/simple_generate_qa_grounded.yaml" ) flow[0]["block_config"]["block_name"] = "gen_skill_grounded" return flow @@ -99,18 +70,12 @@ def get_flow(self) -> list: class MMLUBenchFlow(Flow): def get_flow(self) -> list: - self.sdg_base = resources.files(__package__) return [ { "block_type": LLMBlock, "block_config": { "block_name": "gen_mmlu_knowledge", - "config_path": os.path.join( - self.sdg_base, "configs/knowledge/mcq_generation.yaml" - ), - "client": self.client, - "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_family), + "config_path": "configs/knowledge/mcq_generation.yaml", "output_cols": ["mmlubench_question", "mmlubench_answer"], }, "gen_kwargs": { @@ -129,13 +94,7 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "gen_knowledge", - "config_path": os.path.join( - self.sdg_base, - "configs/knowledge/generate_questions_responses.yaml", - ), - "client": self.client, - "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_family), + "config_path": "configs/knowledge/generate_questions_responses.yaml", "output_cols": ["question", "response"], "parser_kwargs": { "parser_name": "custom", @@ -152,12 +111,7 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "eval_faithfulness_qa_pair", - "config_path": os.path.join( - self.sdg_base, "configs/knowledge/evaluate_faithfulness.yaml" - ), - "client": self.client, - "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_family), + "config_path": "configs/knowledge/evaluate_faithfulness.yaml", "output_cols": ["explanation", "judgment"], }, "gen_kwargs": { @@ -181,12 +135,7 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "eval_relevancy_qa_pair", - "config_path": os.path.join( - self.sdg_base, "configs/knowledge/evaluate_relevancy.yaml" - ), - "client": self.client, - "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_family), + "config_path": "configs/knowledge/evaluate_relevancy.yaml", "output_cols": ["feedback", "score"], }, "gen_kwargs": { @@ -211,12 +160,7 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "eval_verify_question", - "config_path": os.path.join( - self.sdg_base, "configs/knowledge/evaluate_question.yaml" - ), - "client": self.client, - "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_family), + "config_path": "configs/knowledge/evaluate_question.yaml", "output_cols": ["explanation", "rating"], }, "gen_kwargs": { @@ -247,16 +191,10 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "gen_questions", - "config_path": os.path.join( - self.sdg_base, - "configs/skills/freeform_questions.yaml", - ), - "client": self.client, - "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_family), + "config_path": "configs/skills/freeform_questions.yaml", "output_cols": ["question"], "batch_kwargs": { - "num_samples": self.num_instructions_to_generate, + "num_samples": self.ctx.num_instructions_to_generate, }, }, "drop_duplicates": ["question"], @@ -265,13 +203,7 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "eval_questions", - "config_path": os.path.join( - self.sdg_base, - "configs/skills/evaluate_freeform_questions.yaml", - ), - "client": self.client, - "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_family), 
+ "config_path": "configs/skills/evaluate_freeform_questions.yaml", "output_cols": ["evaluation", "score"], }, }, @@ -293,13 +225,7 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "gen_responses", - "config_path": os.path.join( - self.sdg_base, - "configs/skills/freeform_responses.yaml", - ), - "client": self.client, - "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_family), + "config_path": "configs/skills/freeform_responses.yaml", "output_cols": ["response"], }, }, @@ -307,13 +233,7 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "evaluate_qa_pair", - "config_path": os.path.join( - self.sdg_base, - "configs/skills/evaluate_freeform_pair.yaml", - ), - "client": self.client, - "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_family), + "config_path": "configs/skills/evaluate_freeform_pair.yaml", "output_cols": ["evaluation", "score"], }, }, @@ -341,19 +261,13 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "gen_contexts", - "config_path": os.path.join( - self.sdg_base, - "configs/skills/contexts.yaml", - ), - "client": self.client, - "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_family), + "config_path": "configs/skills/contexts.yaml", "output_cols": ["context"], }, "gen_kwargs": { "temperature": 0.7, "max_tokens": 2048, - "n": self.num_instructions_to_generate, + "n": self.ctx.num_instructions_to_generate, }, "drop_duplicates": ["context"], }, @@ -361,13 +275,7 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "gen_grounded_questions", - "config_path": os.path.join( - self.sdg_base, - "configs/skills/grounded_questions.yaml", - ), - "client": self.client, - "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_family), + "config_path": "configs/skills/grounded_questions.yaml", "output_cols": ["question"], "batch_kwargs": { "num_samples": 3, @@ -379,13 +287,7 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "eval_grounded_questions", - "config_path": os.path.join( - self.sdg_base, - "configs/skills/evaluate_grounded_questions.yaml", - ), - "client": self.client, - "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_family), + "config_path": "configs/skills/evaluate_grounded_questions.yaml", "output_cols": ["evaluation", "score"], }, }, @@ -407,13 +309,7 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "gen_grounded_responses", - "config_path": os.path.join( - self.sdg_base, - "configs/skills/grounded_responses.yaml", - ), - "client": self.client, - "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_family), + "config_path": "configs/skills/grounded_responses.yaml", "output_cols": ["response"], }, }, @@ -421,13 +317,7 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "evaluate_grounded_qa_pair", - "config_path": os.path.join( - self.sdg_base, - "configs/skills/evaluate_grounded_pair.yaml", - ), - "client": self.client, - "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_family), + "config_path": "configs/skills/evaluate_grounded_pair.yaml", "output_cols": ["evaluation", "score"], }, }, diff --git a/src/instructlab/sdg/filterblock.py b/src/instructlab/sdg/filterblock.py index f5551b02..afb58b7b 100644 --- a/src/instructlab/sdg/filterblock.py +++ 
@@ -9,14 +9,56 @@
 logger = setup_logger(__name__)


+# Note - this is not a method on the class below in order to avoid
+# serializing the object itself when multi-processing is used.
+# In particular, SSLContext - embedded in the OpenAI client object -
+# cannot be pickled.
+def _filter_by_values(samples, column, op, values, num_proc=1):
+    return samples.filter(
+        lambda x: any(op(x[column], value) for value in values),
+        num_proc=num_proc,
+    )
+
+
+def _map_dtype(samples, column, dtype, num_proc=1):
+    def convert_column(sample):
+        try:
+            sample[column] = dtype(sample[column])
+        except ValueError as e:
+            logger.error(
+                "Error converting dtype: %s, filling with None to be filtered later", e
+            )
+            sample[column] = None
+        return sample
+
+    # FIXME: it appears multiprocessing map has issues with
+    # None columns. If we pass num_proc>1 here and the error
+    # case is triggered above, we get:
+    #   ValueError: The features can't be aligned ...
+    # because the column is still considered a string not
+    # the new dtype.
+    num_proc = 1
+
+    return samples.map(convert_column, num_proc=num_proc)
+
+
 class FilterByValueBlock(Block):
     def __init__(
-        self, filter_column, filter_value, operation, convert_dtype=None, **batch_kwargs
+        self,
+        ctx,
+        block_name,
+        filter_column,
+        filter_value,
+        operation,
+        convert_dtype=None,
+        **batch_kwargs,
     ) -> None:
         """
         Initializes a new instance of the FilterByValueBlock class.

         Parameters:
+        - ctx (PipelineContext): A PipelineContext object containing runtime parameters.
+        - block_name (str): An identifier for this block.
         - filter_column (str): The name of the column in the dataset to apply the filter on.
         - filter_value (any or list of any): The value(s) to filter by.
         - operation (callable): A function that takes two arguments (column value and filter value) and returns a boolean indicating whether the row should be included in the filtered dataset.
@@ -26,33 +68,19 @@ def __init__(
         Returns:
             None
         """
-        super().__init__(block_name=self.__class__.__name__)
+        super().__init__(ctx, block_name)
         self.value = filter_value if isinstance(filter_value, list) else [filter_value]
         self.column_name = filter_column
         self.operation = operation
         self.convert_dtype = convert_dtype
         self.num_procs = batch_kwargs.get("num_procs", 1)

-    def _convert_dtype(self, sample):
-        try:
-            sample[self.column_name] = self.convert_dtype(sample[self.column_name])
-        except ValueError as e:
-            logger.error(
-                "Error converting dtype: %s, filling with None to be filtered later", e
-            )
-            sample[self.column_name] = None
-        return sample
-
     def generate(self, samples) -> Dataset:
         if self.convert_dtype:
-            samples = samples.map(
-                self._convert_dtype,
-                num_proc=self.num_procs,
+            samples = _map_dtype(
+                samples, self.column_name, self.convert_dtype, self.num_procs
             )

-        return samples.filter(
-            lambda x: any(
-                self.operation(x[self.column_name], value) for value in self.value
-            ),
-            num_proc=self.num_procs,
+        return _filter_by_values(
+            samples, self.column_name, self.operation, self.value, self.num_procs
         )
diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py
index 36c6cad4..abcd6665 100644
--- a/src/instructlab/sdg/generate_data.py
+++ b/src/instructlab/sdg/generate_data.py
@@ -18,8 +18,6 @@
 # pylint: disable=ungrouped-imports
 from instructlab.sdg import SDG, utils
 from instructlab.sdg.default_flows import (
-    MODEL_FAMILY_MERLINITE,
-    MODEL_FAMILY_MIXTRAL,
     MMLUBenchFlow,
     SimpleFreeformSkillFlow,
     SimpleGroundedSkillFlow,
@@ -28,7 +26,8 @@
     SynthKnowledgeFlow,
     SynthSkillsFlow,
 )
-from instructlab.sdg.pipeline import Pipeline
+from instructlab.sdg.llmblock import MODEL_FAMILY_MERLINITE, MODEL_FAMILY_MIXTRAL
+from instructlab.sdg.pipeline import Pipeline, PipelineContext
 from instructlab.sdg.utils import models
 from instructlab.sdg.utils.taxonomy import (
     leaf_node_to_samples,
@@ -184,37 +183,25 @@ def _sdg_init(pipeline, client, model_family, model_name, num_instructions_to_generate):
     else:
         raise utils.GenerateException(f"Error: pipeline ({pipeline}) is not supported.")

-    sdg_knowledge = SDG(
-        [
-            Pipeline(
-                flow_type(
-                    client, model_family, model_name, num_instructions_to_generate
-                ).get_flow()
-            )
-            for flow_type in knowledge_flow_types
-        ]
-    )
-    sdg_freeform_skill = SDG(
-        [
-            Pipeline(
-                flow_type(
-                    client, model_family, model_name, num_instructions_to_generate
-                ).get_flow()
-            )
-            for flow_type in freeform_skill_flow_types
-        ]
+    ctx = PipelineContext(
+        client, model_family, model_name, num_instructions_to_generate
     )
-    sdg_grounded_skill = SDG(
-        [
-            Pipeline(
-                flow_type(
-                    client, model_family, model_name, num_instructions_to_generate
-                ).get_flow()
-            )
-            for flow_type in grounded_skill_flow_types
-        ]
+
+    def build_pipeline(flow_types):
+        block_configs = []
+        for flow_type in flow_types:
+            block_configs.extend(flow_type(ctx).get_flow())
+        return Pipeline(ctx, block_configs)
+
+    knowledge_pipeline = build_pipeline(knowledge_flow_types)
+    freeform_skill_pipeline = build_pipeline(freeform_skill_flow_types)
+    grounded_skill_pipeline = build_pipeline(grounded_skill_flow_types)
+
+    return (
+        SDG([knowledge_pipeline]),
+        SDG([freeform_skill_pipeline]),
+        SDG([grounded_skill_pipeline]),
     )
-    return sdg_knowledge, sdg_freeform_skill, sdg_grounded_skill


 # TODO - parameter removal needs to be done in sync with a CLI change.
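
The net effect of the _sdg_init() rewrite above is easier to see in isolation. The sketch below is illustrative rather than part of the patch; the endpoint URL and model names are placeholder assumptions. It shows the calling convention the refactoring establishes: one PipelineContext shared by every flow and block, and flows chained by list concatenation, since get_flow() returns a plain list of block configs.

# Illustrative sketch only -- not part of this patch.
# The endpoint URL and model names below are placeholders.
from openai import OpenAI

from instructlab.sdg import SDG
from instructlab.sdg.default_flows import MMLUBenchFlow, SynthKnowledgeFlow
from instructlab.sdg.pipeline import Pipeline, PipelineContext

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# One context object now carries the client and model parameters that
# every block previously received as individual constructor arguments.
ctx = PipelineContext(client, "mixtral", "teacher-model", 1)

# get_flow() returns a list of block configs, so the MMLU bench flow and
# the knowledge flow chain into a single Pipeline by concatenation.
knowledge_pipe = Pipeline(
    ctx, MMLUBenchFlow(ctx).get_flow() + SynthKnowledgeFlow(ctx).get_flow()
)
sdg = SDG([knowledge_pipe])
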
diff --git a/src/instructlab/sdg/llmblock.py b/src/instructlab/sdg/llmblock.py
index 4153a191..eaa58556 100644
--- a/src/instructlab/sdg/llmblock.py
+++ b/src/instructlab/sdg/llmblock.py
@@ -13,6 +13,23 @@
 logger = setup_logger(__name__)

+MODEL_FAMILY_MIXTRAL = "mixtral"
+MODEL_FAMILY_MERLINITE = "merlinite"
+
+_MODEL_PROMPT_MIXTRAL = " [INST] {prompt} [/INST]"
+_MODEL_PROMPT_MERLINITE = "'<|system|>\nYou are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.\n<|user|>\n{prompt}\n<|assistant|>\n'"
+
+_MODEL_PROMPTS = {
+    MODEL_FAMILY_MIXTRAL: _MODEL_PROMPT_MIXTRAL,
+    MODEL_FAMILY_MERLINITE: _MODEL_PROMPT_MERLINITE,
+}
+
+
+def _get_model_prompt(model_family):
+    if model_family not in _MODEL_PROMPTS:
+        raise ValueError(f"Unknown model family: {model_family}")
+    return _MODEL_PROMPTS[model_family]
+

 def server_supports_batched(client, model_id: str) -> bool:
     supported = getattr(client, "server_supports_batched", None)
@@ -38,38 +55,36 @@ class LLMBlock(Block):
     # pylint: disable=too-many-instance-attributes
     def __init__(
         self,
+        ctx,
         block_name,
         config_path,
-        client,
-        model_id,
         output_cols,
         parser_kwargs={},
-        model_prompt="{prompt}",
         **batch_kwargs,
     ) -> None:
-        super().__init__(block_name)
+        super().__init__(ctx, block_name)
         self.block_config = self._load_config(config_path)
         self.prompt_struct = (
             """{system}\n{introduction}\n{principles}\n{examples}\n{generation}"""
         )
         self.prompt_template = self.prompt_struct.format(**self.block_config)
-        self.client = client
-        self.model = model_id
-        self.model_prompt = model_prompt
+        self.model_prompt = _get_model_prompt(self.ctx.model_family)
         self.output_cols = output_cols
         self.batch_params = batch_kwargs.get("batch_kwargs", {})
         self.parser_name = parser_kwargs.get("parser_name", None)
         self.parsing_pattern = parser_kwargs.get("parsing_pattern", None)
         self.parser_cleanup_tags = parser_kwargs.get("parser_cleanup_tags", None)
         self.defaults = {
-            "model": self.model,
+            "model": self.ctx.model_id,
             "temperature": 0,
             "max_tokens": 12000,
         }

         # Whether the LLM server supports a list of input prompts
         # and supports the n parameter to generate n outputs per input
-        self.server_supports_batched = server_supports_batched(client, model_id)
+        self.server_supports_batched = server_supports_batched(
+            self.ctx.client, self.ctx.model_id
+        )

     def _parse(self, generated_string) -> dict:
         matches = {}
@@ -119,14 +134,16 @@ def _generate(self, samples, **gen_kwargs) -> list:
         generate_args = {**self.defaults, **gen_kwargs}

         if self.server_supports_batched:
-            response = self.client.completions.create(prompt=prompts, **generate_args)
+            response = self.ctx.client.completions.create(
+                prompt=prompts, **generate_args
+            )
             return [choice.text.strip() for choice in response.choices]

         n = gen_kwargs.get("n", 1)
         results = []
         for prompt in prompts:
             for _ in range(n):
-                response = self.client.completions.create(
+                response = self.ctx.client.completions.create(
                     prompt=prompt, **generate_args
                 )
                 results.append(response.choices[0].text.strip())
@@ -189,24 +206,20 @@ def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
 class ConditionalLLMBlock(LLMBlock):
     def __init__(
         self,
+        ctx,
         block_name,
         config_paths,
-        client,
-        model_id,
         output_cols,
         selector_column_name,
         parser_kwargs={},
-        model_prompt="{prompt}",
         **batch_kwargs,
     ) -> None:
         super().__init__(
+            ctx,
             block_name,
             config_paths[0][0],
-            client,
-            model_id,
             output_cols,
             parser_kwargs=parser_kwargs,
-            model_prompt=model_prompt,
             **batch_kwargs,
         )
         self.selector_column_name = selector_column_name
diff --git a/src/instructlab/sdg/pipeline.py b/src/instructlab/sdg/pipeline.py
index bc570a83..93464601 100644
--- a/src/instructlab/sdg/pipeline.py
+++ b/src/instructlab/sdg/pipeline.py
@@ -1,4 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
+# Standard
+from importlib import resources
+
 # Third Party
 from datasets import Dataset

@@ -8,12 +11,25 @@
 logger = setup_logger(__name__)


+class PipelineContext:
+    def __init__(
+        self, client, model_family, model_id, num_instructions_to_generate
+    ) -> None:
+        self.client = client
+        self.model_family = model_family
+        self.model_id = model_id
+        self.num_instructions_to_generate = num_instructions_to_generate
+        self.sdg_base = resources.files(__package__)
+
+
 class Pipeline:
-    def __init__(self, chained_blocks: list) -> None:
+    def __init__(self, ctx, chained_blocks: list) -> None:
         """
         Initialize the Pipeline class with a configuration dictionary.
         config_dict: the run config py or yaml loaded into a dictionary
         """
+        # ctx is a PipelineContext object that supplies context configuration to every block
+        self.ctx = ctx
         # pipeline config is the run configuration that consists of the pipeline steps
         self.chained_blocks = chained_blocks

@@ -36,7 +52,7 @@ def generate(self, dataset) -> Dataset:
         drop_columns = block_prop.get("drop_columns", [])
         gen_kwargs = block_prop.get("gen_kwargs", {})
         drop_duplicates_cols = block_prop.get("drop_duplicates", False)
-        block = block_type(**block_config)
+        block = block_type(self.ctx, **block_config)

         logger.info("Running block: %s", block_config["block_name"])
         logger.info(dataset)
diff --git a/src/instructlab/sdg/utilblocks.py b/src/instructlab/sdg/utilblocks.py
index db04b5a1..871b2ce8 100644
--- a/src/instructlab/sdg/utilblocks.py
+++ b/src/instructlab/sdg/utilblocks.py
@@ -10,10 +10,10 @@


 class SamplePopulatorBlock(Block):
-    def __init__(self, config_paths, column_name, post_fix="", **batch_kwargs) -> None:
-        super().__init__(
-            block_name=self.__class__.__name__
-        )  # Call the base class's __init__
+    def __init__(
+        self, ctx, block_name, config_paths, column_name, post_fix="", **batch_kwargs
+    ) -> None:
+        super().__init__(ctx, block_name)
         self.configs = {}
         for config in config_paths:
             if post_fix:
@@ -25,46 +25,69 @@ def __init__(self, config_paths, column_name, post_fix="", **batch_kwargs) -> None:
         self.column_name = column_name
         self.num_procs = batch_kwargs.get("num_procs", 8)

-    def _generate(self, sample) -> dict:
-        sample = {**sample, **self.configs[sample[self.column_name]]}
-        return sample
+    # Using a static method to avoid serializing self when using multiprocessing
+    @staticmethod
+    def _map_populate(samples, configs, column_name, num_proc=1):
+        def populate(sample):
+            return {**sample, **configs[sample[column_name]]}
+
+        return samples.map(populate, num_proc=num_proc)

     def generate(self, samples) -> Dataset:
-        samples = samples.map(self._generate, num_proc=self.num_procs)
-        return samples
+        return self._map_populate(
+            samples, self.configs, self.column_name, self.num_procs
+        )


 class SelectorBlock(Block):
-    def __init__(self, choice_map, choice_col, output_col, **batch_kwargs) -> None:
-        super().__init__(block_name=self.__class__.__name__)
+    def __init__(
+        self, ctx, block_name, choice_map, choice_col, output_col, **batch_kwargs
+    ) -> None:
+        super().__init__(ctx, block_name)
         self.choice_map = choice_map
         self.choice_col = choice_col
         self.output_col = output_col
         self.num_procs = batch_kwargs.get("num_procs", 8)
batch_kwargs.get("num_procs", 8) - def _generate(self, sample) -> dict: - sample[self.output_col] = sample[self.choice_map[sample[self.choice_col]]] - return sample + # Using a static method to avoid serializing self when using multiprocessing + @staticmethod + def _map_select_choice(samples, choice_map, choice_col, output_col, num_proc=1): + def select_choice(sample) -> dict: + sample[output_col] = sample[choice_map[sample[choice_col]]] + return sample + + return samples.map(select_choice, num_proc) def generate(self, samples: Dataset) -> Dataset: - samples = samples.map(self._generate, num_proc=self.num_procs) - return samples + return self._map_select_choice( + samples, + self.choice_map, + self.choice_col, + self.output_col, + self.num_procs, + ) class CombineColumnsBlock(Block): - def __init__(self, columns, output_col, separator="\n\n", **batch_kwargs) -> None: - super().__init__(block_name=self.__class__.__name__) + def __init__( + self, ctx, block_name, columns, output_col, separator="\n\n", **batch_kwargs + ) -> None: + super().__init__(ctx, block_name) self.columns = columns self.output_col = output_col self.separator = separator self.num_procs = batch_kwargs.get("num_procs", 8) - def _generate(self, sample) -> dict: - sample[self.output_col] = self.separator.join( - [sample[col] for col in self.columns] - ) - return sample + # Using a static method to avoid serializing self when using multiprocessing + @staticmethod + def _map_combine(samples, columns, output_col, separator, num_proc=1): + def combine(sample): + sample[output_col] = separator.join([sample[col] for col in columns]) + return sample + + return samples.map(combine, num_proc=num_proc) def generate(self, samples: Dataset) -> Dataset: - samples = samples.map(self._generate, num_proc=self.num_procs) - return samples + return self._map_combine( + samples, self.columns, self.output_col, self.separator, self.num_procs + ) diff --git a/tests/test_filterblock.py b/tests/test_filterblock.py index 7b8b1ce7..5e00c80b 100644 --- a/tests/test_filterblock.py +++ b/tests/test_filterblock.py @@ -8,17 +8,22 @@ # First Party from instructlab.sdg.filterblock import FilterByValueBlock +from instructlab.sdg.pipeline import PipelineContext class TestFilterByValueBlock(unittest.TestCase): def setUp(self): self.block = FilterByValueBlock( + PipelineContext(None, None, None, None), + block_name="filter_by_age", filter_column="age", filter_value=30, operation=operator.eq, convert_dtype=int, ) self.block_with_list = FilterByValueBlock( + PipelineContext(None, None, None, None), + block_name="filter_by_ages", filter_column="age", filter_value=[30, 35], operation=operator.eq,
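
The updated tests also double as a recipe for driving a single block outside of any pipeline. Below is a minimal sketch along the same lines, reusing the stub-context trick from the test above; the sample rows are invented for illustration. Blocks that never touch ctx.client, such as FilterByValueBlock, only need a placeholder PipelineContext.

# Minimal standalone sketch -- sample data is invented for illustration.
import operator

from datasets import Dataset

from instructlab.sdg.filterblock import FilterByValueBlock
from instructlab.sdg.pipeline import PipelineContext

# A stub context is enough for blocks that never use ctx.client.
ctx = PipelineContext(None, None, None, None)

block = FilterByValueBlock(
    ctx,
    block_name="filter_by_age",
    filter_column="age",
    filter_value=30,
    operation=operator.eq,
    convert_dtype=int,  # string ages are converted before comparison
)

ds = Dataset.from_list([{"age": "30"}, {"age": "35"}, {"age": "unknown"}])
# "unknown" fails int() conversion, is set to None, and is filtered out;
# only the age == 30 row survives.
print(block.generate(ds))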