Skip to content

Commit

Permalink
Merge pull request #127 from russellb/empty-dataset-exception
Browse files Browse the repository at this point in the history
pipeline: Fail explicitly on an empty dataset
  • Loading branch information
russellb authored Jul 13, 2024
2 parents 13df1c9 + 8f03a1f commit e636680
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
10 changes: 10 additions & 0 deletions src/instructlab/sdg/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
logger = setup_logger(__name__)


class EmptyDatasetError(Exception):
pass


class PipelineContext:
def __init__(
self, client, model_family, model_id, num_instructions_to_generate
Expand Down Expand Up @@ -72,6 +76,12 @@ def generate(self, dataset) -> Dataset:

dataset = block.generate(dataset, **gen_kwargs)

# If at any point we end up with an empty data set, the pipeline has failed
if len(dataset) == 0:
raise EmptyDatasetError(
f"Pipeline stopped: Empty dataset after running block: {block_name}"
)

drop_columns_in_ds = [e for e in drop_columns if e in dataset.column_names]
if drop_columns:
dataset = dataset.remove_columns(drop_columns_in_ds)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_default_pipeline_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def _noop_generate(self, samples, **gen_kwargs):
@patch.object(SamplePopulatorBlock, "generate", _noop_generate)
@patch.object(SelectorBlock, "generate", _noop_generate)
@patch("instructlab.sdg.llmblock.server_supports_batched", lambda c, m: True)
@patch.object(Pipeline, "_drop_duplicates", lambda self, dataset, cols: dataset)
class TestDefaultPipelineConfigs(unittest.TestCase):
def setUp(self):
self._yaml_files = [
Expand All @@ -49,5 +50,5 @@ def test_pipeline_from_config(self):
)
for pipeline_yaml in self._yaml_files:
pipeline = Pipeline.from_file(ctx, pipeline_yaml)
output = pipeline.generate(Dataset.from_list([]))
output = pipeline.generate(Dataset.from_list([{"test": "test"}]))
self.assertIsNotNone(output)

0 comments on commit e636680

Please sign in to comment.