diff --git a/src/instructlab/sdg/__init__.py b/src/instructlab/sdg/__init__.py
index 348e26a4..06df1200 100644
--- a/src/instructlab/sdg/__init__.py
+++ b/src/instructlab/sdg/__init__.py
@@ -12,6 +12,7 @@
     "ImportBlock",
     "LLMBlock",
     "Pipeline",
+    "PipelineBlockError",
     "PipelineConfigParserError",
     "PipelineContext",
     "SamplePopulatorBlock",
@@ -33,6 +34,7 @@
     SIMPLE_PIPELINES_PACKAGE,
     EmptyDatasetError,
     Pipeline,
+    PipelineBlockError,
     PipelineConfigParserError,
     PipelineContext,
 )
diff --git a/src/instructlab/sdg/pipeline.py b/src/instructlab/sdg/pipeline.py
index 05044dc0..8b006435 100644
--- a/src/instructlab/sdg/pipeline.py
+++ b/src/instructlab/sdg/pipeline.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # Standard
 from importlib import resources
+from typing import Optional
 import os.path
 
 # Third Party
@@ -9,6 +10,7 @@
 
 # Local
 from . import filterblock, importblock, llmblock, utilblocks
+from .block import Block
 from .logger_config import setup_logger
 
 logger = setup_logger(__name__)
@@ -32,6 +34,44 @@ def __init__(
         self.num_procs = 8
 
 
+# This is part of the public API.
+class PipelineBlockError(Exception):
+    """A PipelineBlockError occurs when a block generates an exception during
+    generation. It contains information about which block failed and why.
+
+    Args:
+        exception: The original exception raised while setting up or running
+            the block.
+        block: The block instance that failed, if it was instantiated.
+        block_name: Name of the failing block. Defaults to block.block_name
+            when a block is given.
+        block_type: Type name of the failing block. Defaults to the class name
+            of block when a block is given.
+    """
+
+    def __init__(
+        self,
+        exception: Exception,
+        *,
+        block: Optional[Block] = None,
+        block_name: Optional[str] = None,
+        block_type: Optional[str] = None,
+    ):
+        super().__init__(exception)
+        self.exception = exception
+        self.block = block
+        self.block_name = block_name or (block.block_name if block else None)
+        self.block_type = block_type or (block.__class__.__name__ if block else None)
+
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__}({self.block_type}/{self.block_name}): {self.exception_message}"
+
+    @property
+    def exception_message(self) -> str:
+        """String form of the wrapped exception."""
+        return str(self.exception)
+
+
 # This is part of the public API.
 class Pipeline:
     def __init__(self, ctx, config_path, chained_blocks: list) -> None:
@@ -67,17 +107,31 @@ def generate(self, dataset) -> Dataset:
         dataset: the input dataset
         """
         for block_prop in self.chained_blocks:
-            block_name = block_prop["name"]
-            block_type = _lookup_block_type(block_prop["type"])
-            block_config = block_prop["config"]
-            drop_columns = block_prop.get("drop_columns", [])
-            drop_duplicates_cols = block_prop.get("drop_duplicates", False)
-            block = block_type(self.ctx, self, block_name, **block_config)
-
-            logger.info("Running block: %s", block_name)
-            logger.info(dataset)
-
-            dataset = block.generate(dataset)
+            # Initialize arguments for error handling to None
+            block, block_name, block_type_name = None, None, None
+            try:
+                # Parse and instantiate the block. Keep the configured type
+                # name separately: block_type is the class itself, while the
+                # error wants a string (and wants it even if lookup fails).
+                block_name = block_prop["name"]
+                block_type_name = block_prop["type"]
+                block_type = _lookup_block_type(block_type_name)
+                block_config = block_prop["config"]
+                drop_columns = block_prop.get("drop_columns", [])
+                drop_duplicates_cols = block_prop.get("drop_duplicates", False)
+                block = block_type(self.ctx, self, block_name, **block_config)
+                logger.info("Running block: %s", block_name)
+                logger.info(dataset)
+
+                # Execute the block and wrap errors with the block name/type
+                dataset = block.generate(dataset)
+            except Exception as err:
+                raise PipelineBlockError(
+                    exception=err,
+                    block=block,
+                    block_name=block_name,
+                    block_type=block_type_name,
+                ) from err
 
             # If at any point we end up with an empty data set, the pipeline has failed
             if len(dataset) == 0:
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
new file mode 100644
index 00000000..443eda26
--- /dev/null
+++ b/tests/test_pipeline.py
@@ -0,0 +1,105 @@
+"""
+Unit tests for common Pipeline functionality
+"""
+
+# Standard
+from unittest import mock
+
+# Third Party
+from datasets import Dataset
+import pytest
+
+# First Party
+from instructlab.sdg.block import Block
+from instructlab.sdg.pipeline import Pipeline, PipelineBlockError
+
+
+def test_pipeline_named_errors_match_type():
+    """Validate that a PipelineBlockError is raised to wrap exceptions raised
+    in a Block's generate method
+    """
+    mock_dataset = ["not empty"]
+    working_block = mock.MagicMock()
+    working_block().generate.return_value = mock_dataset
+    failure_block = mock.MagicMock()
+    failure_block.__name__ = "BadBlock"
+    failure_exc = RuntimeError("Oh no!")
+    failure_block().generate = mock.MagicMock(side_effect=failure_exc)
+    pipe_cfg = [
+        {"name": "I work", "type": "working", "config": {}},
+        {"name": "I don't", "type": "failure", "config": {}},
+    ]
+    with mock.patch(
+        "instructlab.sdg.pipeline._block_types",
+        {
+            "working": working_block,
+            "failure": failure_block,
+        },
+    ):
+        pipe = Pipeline(None, None, pipe_cfg)
+        with pytest.raises(PipelineBlockError) as exc_ctx:
+            pipe.generate(None)
+
+        assert exc_ctx.value.__cause__ is failure_exc
+        assert exc_ctx.value.exception is failure_exc
+        assert exc_ctx.value.block is failure_block()
+        assert exc_ctx.value.block_name == "I don't"
+        assert exc_ctx.value.block_type == "failure"
+
+
+def test_pipeline_config_error_handling():
+    """Validate that a PipelineBlockError is raised when block config is
+    incorrect
+    """
+    pipe_cfg = [
+        {"name_not_there": "I work", "type": "working", "config": {}},
+        {"name": "I don't", "type": "failure", "config": {}},
+    ]
+    pipe = Pipeline(None, None, pipe_cfg)
+    with pytest.raises(PipelineBlockError) as exc_ctx:
+        pipe.generate(None)
+
+    assert isinstance(exc_ctx.value.__cause__, KeyError)
+
+
+def test_block_generation_error_properties_from_block():
+    """Make sure the PipelineBlockError exposes its properties and string form
+    correctly when pulled from a Block instance
+    """
+
+    class TestBlock(Block):
+        def generate(self, dataset: Dataset) -> Dataset:
+            return dataset
+
+    block_name = "my-block"
+    block = TestBlock(None, None, block_name)
+    inner_err = TypeError("Not the right type")
+    gen_err = PipelineBlockError(inner_err, block=block)
+    assert gen_err.block is block
+    assert gen_err.exception is inner_err
+    assert gen_err.block_name == block_name
+    assert gen_err.block_type == TestBlock.__name__
+    assert (
+        str(gen_err)
+        == f"{PipelineBlockError.__name__}({TestBlock.__name__}/{block_name}): {inner_err}"
+    )
+
+
+def test_block_generation_error_properties_from_strings():
+    """Make sure the PipelineBlockError exposes its properties and string form
+    correctly when pulled from strings
+    """
+    inner_err = TypeError("Not the right type")
+    block_name = "my-block"
+    block_type = "TestBlock"
+    gen_err = PipelineBlockError(
+        inner_err, block_name=block_name, block_type=block_type
+    )
+    assert gen_err.block is None
+    assert gen_err.exception is inner_err
+    assert gen_err.block_name == block_name
+    assert gen_err.block_type == block_type
+    assert (
+        str(gen_err)
+        == f"{PipelineBlockError.__name__}({block_type}/{block_name}): {inner_err}"
+    )