diff --git a/src/instructlab/sdg/__init__.py b/src/instructlab/sdg/__init__.py index 51e54418..348e26a4 100644 --- a/src/instructlab/sdg/__init__.py +++ b/src/instructlab/sdg/__init__.py @@ -1,3 +1,42 @@ # SPDX-License-Identifier: Apache-2.0 + +# NOTE: This package imports Torch and other heavy packages. +__all__ = ( + "Block", + "CombineColumnsBlock", + "ConditionalLLMBlock", + "EmptyDatasetError", + "FilterByValueBlock", + "FilterByValueBlockError", + "GenerateException", + "ImportBlock", + "LLMBlock", + "Pipeline", + "PipelineConfigParserError", + "PipelineContext", + "SamplePopulatorBlock", + "SelectorBlock", + "SDG", + "SIMPLE_PIPELINES_PACKAGE", + "FULL_PIPELINES_PACKAGE", + "generate_data", +) + # Local +from .block import Block +from .filterblock import FilterByValueBlock, FilterByValueBlockError +from .generate_data import generate_data +from .importblock import ImportBlock +from .llmblock import ConditionalLLMBlock, LLMBlock +from .pipeline import ( + FULL_PIPELINES_PACKAGE, + SIMPLE_PIPELINES_PACKAGE, + EmptyDatasetError, + Pipeline, + PipelineConfigParserError, + PipelineContext, +) from .sdg import SDG +from .utilblocks import CombineColumnsBlock, SamplePopulatorBlock, SelectorBlock +from .utils import GenerateException +from .utils.taxonomy import TaxonomyReadingException diff --git a/src/instructlab/sdg/block.py b/src/instructlab/sdg/block.py index 75b0a4e8..dfde5270 100644 --- a/src/instructlab/sdg/block.py +++ b/src/instructlab/sdg/block.py @@ -14,6 +14,7 @@ logger = setup_logger(__name__) +# This is part of the public API. class Block(ABC): def __init__(self, ctx, pipe, block_name: str) -> None: self.ctx = ctx diff --git a/src/instructlab/sdg/filterblock.py b/src/instructlab/sdg/filterblock.py index 3cc7b427..c9e3f7e9 100644 --- a/src/instructlab/sdg/filterblock.py +++ b/src/instructlab/sdg/filterblock.py @@ -12,6 +12,7 @@ logger = setup_logger(__name__) +# This is part of the public API. class FilterByValueBlockError(Exception): """An exception raised by the FilterByValue block.""" @@ -73,6 +74,7 @@ def convert_column(sample): return samples.map(convert_column, num_proc=num_proc) +# This is part of the public API. class FilterByValueBlock(Block): def __init__( self, diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index 7a926a5a..0e3c42eb 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -17,7 +17,6 @@ # First Party # pylint: disable=ungrouped-imports -from instructlab.sdg import SDG, utils from instructlab.sdg.llmblock import MODEL_FAMILY_MERLINITE, MODEL_FAMILY_MIXTRAL from instructlab.sdg.pipeline import ( FULL_PIPELINES_PACKAGE, @@ -25,7 +24,8 @@ Pipeline, PipelineContext, ) -from instructlab.sdg.utils import models +from instructlab.sdg.sdg import SDG +from instructlab.sdg.utils import GenerateException, models from instructlab.sdg.utils.taxonomy import ( leaf_node_to_samples, read_taxonomy_leaf_nodes, @@ -48,7 +48,7 @@ def _get_question(logger, synth_example): return synth_example["question"] if not synth_example.get("output"): - raise utils.GenerateException( + raise GenerateException( f"Error: output not found in synth_example: {synth_example}" ) @@ -64,7 +64,7 @@ def _get_response(logger, synth_example): return synth_example["response"] if "output" not in synth_example: - raise utils.GenerateException( + raise GenerateException( f"Error: output not found in synth_example: {synth_example}" ) @@ -173,12 +173,12 @@ def _sdg_init(pipeline, client, model_family, model_id, num_instructions_to_gene else: # Validate that pipeline is a valid directory and that it contains the required files if not os.path.exists(pipeline): - raise utils.GenerateException( + raise GenerateException( f"Error: pipeline directory ({pipeline}) does not exist." ) for file in ["knowledge.yaml", "freeform_skills.yaml", "grounded_skills.yaml"]: if not os.path.exists(os.path.join(pipeline, file)): - raise utils.GenerateException( + raise GenerateException( f"Error: pipeline directory ({pipeline}) does not contain {file}." ) @@ -198,6 +198,7 @@ def load_pipeline(yaml_basename): ) +# This is part of the public API, and used by instructlab. # TODO - parameter removal needs to be done in sync with a CLI change. # pylint: disable=unused-argument def generate_data( @@ -226,7 +227,7 @@ def generate_data( tls_client_key: Optional[str] = None, tls_client_passwd: Optional[str] = None, pipeline: Optional[str] = "simple", -): +) -> None: """Generate data for training and testing a model. This currently serves as the primary interface from the `ilab` CLI to the `sdg` library. @@ -246,11 +247,11 @@ def generate_data( os.mkdir(output_dir) if not (taxonomy and os.path.exists(taxonomy)): - raise utils.GenerateException(f"Error: taxonomy ({taxonomy}) does not exist.") + raise GenerateException(f"Error: taxonomy ({taxonomy}) does not exist.") leaf_nodes = read_taxonomy_leaf_nodes(taxonomy, taxonomy_base, yaml_rules) if not leaf_nodes: - raise utils.GenerateException("Error: No new leaf nodes found in the taxonomy.") + raise GenerateException("Error: No new leaf nodes found in the taxonomy.") name = Path(model_name).stem # Just in case it is a file path date_suffix = datetime.now().replace(microsecond=0).isoformat().replace(":", "_") @@ -301,7 +302,7 @@ def generate_data( samples = leaf_node_to_samples(leaf_node, server_ctx_size, chunk_word_count) if not samples: - raise utils.GenerateException("Error: No samples found in leaf node.") + raise GenerateException("Error: No samples found in leaf node.") if samples[0].get("document"): sdg = sdg_knowledge diff --git a/src/instructlab/sdg/importblock.py b/src/instructlab/sdg/importblock.py index 5fa479b8..8efb2550 100644 --- a/src/instructlab/sdg/importblock.py +++ b/src/instructlab/sdg/importblock.py @@ -10,6 +10,7 @@ logger = setup_logger(__name__) +# This is part of the public API. class ImportBlock(Block): def __init__( self, diff --git a/src/instructlab/sdg/llmblock.py b/src/instructlab/sdg/llmblock.py index 8d0b26c4..f96d92ee 100644 --- a/src/instructlab/sdg/llmblock.py +++ b/src/instructlab/sdg/llmblock.py @@ -50,6 +50,7 @@ def server_supports_batched(client, model_id: str) -> bool: return supported +# This is part of the public API. # pylint: disable=dangerous-default-value class LLMBlock(Block): # pylint: disable=too-many-instance-attributes @@ -212,6 +213,7 @@ def generate(self, samples: Dataset, **gen_kwargs) -> Dataset: return Dataset.from_list(new_data) +# This is part of the public API. class ConditionalLLMBlock(LLMBlock): def __init__( self, diff --git a/src/instructlab/sdg/pipeline.py b/src/instructlab/sdg/pipeline.py index 40541bbf..2673d6c4 100644 --- a/src/instructlab/sdg/pipeline.py +++ b/src/instructlab/sdg/pipeline.py @@ -14,10 +14,12 @@ logger = setup_logger(__name__) +# This is part of the public API. class EmptyDatasetError(Exception): pass +# This is part of the public API. class PipelineContext: def __init__( self, client, model_family, model_id, num_instructions_to_generate @@ -30,6 +32,7 @@ def __init__( self.num_procs = 8 +# This is part of the public API. class Pipeline: def __init__(self, ctx, config_path, chained_blocks: list) -> None: """ @@ -113,6 +116,7 @@ def _lookup_block_type(block_type): _PIPELINE_CONFIG_PARSER_MINOR = 0 +# This is part of the public API. class PipelineConfigParserError(Exception): """An exception raised while parsing a pipline config file.""" @@ -141,5 +145,6 @@ def _parse_pipeline_config_file(pipeline_yaml): return content["blocks"] +# This is part of the public API. SIMPLE_PIPELINES_PACKAGE = "instructlab.sdg.pipelines.simple" FULL_PIPELINES_PACKAGE = "instructlab.sdg.pipelines.full" diff --git a/src/instructlab/sdg/sdg.py b/src/instructlab/sdg/sdg.py index c3bce90f..7bfba702 100644 --- a/src/instructlab/sdg/sdg.py +++ b/src/instructlab/sdg/sdg.py @@ -6,6 +6,7 @@ from .pipeline import Pipeline +# This is part of the public API. class SDG: def __init__(self, pipelines: list[Pipeline]) -> None: self.pipelines = pipelines diff --git a/src/instructlab/sdg/utilblocks.py b/src/instructlab/sdg/utilblocks.py index 02b536f5..2bd1392b 100644 --- a/src/instructlab/sdg/utilblocks.py +++ b/src/instructlab/sdg/utilblocks.py @@ -9,6 +9,7 @@ logger = setup_logger(__name__) +# This is part of the public API. class SamplePopulatorBlock(Block): def __init__( self, ctx, pipe, block_name, config_paths, column_name, post_fix="" @@ -38,6 +39,7 @@ def generate(self, samples) -> Dataset: ) +# This is part of the public API. class SelectorBlock(Block): def __init__( self, ctx, pipe, block_name, choice_map, choice_col, output_col @@ -66,6 +68,7 @@ def generate(self, samples: Dataset) -> Dataset: ) +# This is part of the public API. class CombineColumnsBlock(Block): def __init__( self, ctx, pipe, block_name, columns, output_col, separator="\n\n" diff --git a/src/instructlab/sdg/utils/__init__.py b/src/instructlab/sdg/utils/__init__.py index 5faacb12..e3bd0771 100644 --- a/src/instructlab/sdg/utils/__init__.py +++ b/src/instructlab/sdg/utils/__init__.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# This is part of the public API, and used by instructlab +# This is part of the public API, and used by instructlab. class GenerateException(Exception): """An exception raised during generate step.""" diff --git a/src/instructlab/sdg/utils/taxonomy.py b/src/instructlab/sdg/utils/taxonomy.py index d6f6441b..c8389247 100644 --- a/src/instructlab/sdg/utils/taxonomy.py +++ b/src/instructlab/sdg/utils/taxonomy.py @@ -32,6 +32,7 @@ """ +# This is part of the public API. class TaxonomyReadingException(Exception): """An exception raised during reading of the taxonomy."""