From 8f03a1fde5f126a73d8e68a02ba5f566a3daa6ae Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Thu, 11 Jul 2024 22:52:48 -0400 Subject: [PATCH] pipeline: Fail explicitly on an empty dataset If at any point generate() on a block returns an empty dataset, this is a failure condition. Go ahead and raise an exception right away in that case. This change was originally a subset of this commit: https://github.com/aakankshaduggal/sdg/pull/3/commits/256335e4a4f2e6f8642a6df279f16d9d85645445 Co-authored-by: shiv Co-authored-by: Aakanksha Duggal Co-authored-by: Kai Xu Signed-off-by: Russell Bryant --- src/instructlab/sdg/pipeline.py | 10 ++++++++++ tests/test_default_pipeline_configs.py | 3 ++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/instructlab/sdg/pipeline.py b/src/instructlab/sdg/pipeline.py index 3ee08306..40541bbf 100644 --- a/src/instructlab/sdg/pipeline.py +++ b/src/instructlab/sdg/pipeline.py @@ -14,6 +14,10 @@ logger = setup_logger(__name__) +class EmptyDatasetError(Exception): + pass + + class PipelineContext: def __init__( self, client, model_family, model_id, num_instructions_to_generate @@ -72,6 +76,12 @@ def generate(self, dataset) -> Dataset: dataset = block.generate(dataset, **gen_kwargs) + # If at any point we end up with an empty data set, the pipeline has failed + if len(dataset) == 0: + raise EmptyDatasetError( + f"Pipeline stopped: Empty dataset after running block: {block_name}" + ) + drop_columns_in_ds = [e for e in drop_columns if e in dataset.column_names] if drop_columns: dataset = dataset.remove_columns(drop_columns_in_ds) diff --git a/tests/test_default_pipeline_configs.py b/tests/test_default_pipeline_configs.py index 211cf4de..676535dd 100644 --- a/tests/test_default_pipeline_configs.py +++ b/tests/test_default_pipeline_configs.py @@ -28,6 +28,7 @@ def _noop_generate(self, samples, **gen_kwargs): @patch.object(SamplePopulatorBlock, "generate", _noop_generate) @patch.object(SelectorBlock, "generate", _noop_generate) @patch("instructlab.sdg.llmblock.server_supports_batched", lambda c, m: True) +@patch.object(Pipeline, "_drop_duplicates", lambda self, dataset, cols: dataset) class TestDefaultPipelineConfigs(unittest.TestCase): def setUp(self): self._yaml_files = [ @@ -49,5 +50,5 @@ def test_pipeline_from_config(self): ) for pipeline_yaml in self._yaml_files: pipeline = Pipeline.from_file(ctx, pipeline_yaml) - output = pipeline.generate(Dataset.from_list([])) + output = pipeline.generate(Dataset.from_list([{"test": "test"}])) self.assertIsNotNone(output)