Skip to content

Commit

Permalink
Add ImportBlock to allow extending existing pipelines
Browse files Browse the repository at this point in the history
This is to enable the common case of a custom pipeline that
extends an existing pipeline, commonly either by prepending
or appending to the existing pipeline.

The format looks like e.g.:

```
version: "1.0"
blocks:
- <some blocks>
- name: import_child
  type: ImportBlock
  config:
    path: pipelines/full/knowledge.yaml
- <some more blocks>
```

Signed-off-by: Mark McLoughlin <[email protected]>
  • Loading branch information
markmc committed Jul 12, 2024
1 parent eb2719f commit 46f16c6
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 1 deletion.
34 changes: 34 additions & 0 deletions src/instructlab/sdg/importblock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
# Third Party
from datasets import Dataset

# Local
from . import pipeline
from .block import Block
from .logger_config import setup_logger

logger = setup_logger(__name__)


class ImportBlock(Block):
    """Block that runs a chain of blocks imported from another pipeline config.

    This enables the common case of a custom pipeline that extends an
    existing pipeline, typically by prepending or appending blocks around
    an ``ImportBlock`` entry.
    """

    def __init__(
        self,
        ctx,
        block_name,
        path,
    ) -> None:
        """
        ImportBlock imports a chain of blocks from another pipeline config file.

        Parameters:
        - ctx (PipelineContext): A PipelineContext object containing runtime parameters.
        - block_name (str): An identifier for this block.
        - path (str): A path (absolute, or relative to the instructlab.sdg package) to a pipeline config file.
        """
        super().__init__(ctx, block_name)
        self.path = path
        # Load the child pipeline eagerly so a bad config path fails at
        # construction time rather than mid-generation.
        self.pipeline = pipeline.Pipeline.from_file(self.ctx, self.path)

    def generate(self, samples) -> Dataset:
        """Run the imported pipeline's blocks over *samples* and return the result."""
        # Bug fix: the original string lacked the f-prefix, so the literal
        # text "{self.path}" was logged. Use lazy %-style logging args so
        # formatting only happens when the INFO level is enabled.
        logger.info("ImportBlock chaining to blocks from %s", self.path)
        return self.pipeline.generate(samples)
3 changes: 2 additions & 1 deletion src/instructlab/sdg/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import yaml

# Local
from . import filterblock, llmblock, utilblocks
from . import filterblock, importblock, llmblock, utilblocks
from .logger_config import setup_logger

logger = setup_logger(__name__)
Expand Down Expand Up @@ -85,6 +85,7 @@ def generate(self, dataset) -> Dataset:
"CombineColumnsBlock": utilblocks.CombineColumnsBlock,
"ConditionalLLMBlock": llmblock.ConditionalLLMBlock,
"FilterByValueBlock": filterblock.FilterByValueBlock,
"ImportBlock": importblock.ImportBlock,
"LLMBlock": llmblock.LLMBlock,
"SamplePopulatorBlock": utilblocks.SamplePopulatorBlock,
"SelectorBlock": utilblocks.SelectorBlock,
Expand Down
103 changes: 103 additions & 0 deletions tests/test_importblock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Standard
from unittest.mock import MagicMock, patch
import os
import tempfile
import unittest

# Third Party
from datasets import Dataset, Features, Value

# First Party
from instructlab.sdg.importblock import ImportBlock
from instructlab.sdg.pipeline import Pipeline


class TestImportBlockWithMockPipeline(unittest.TestCase):
    """Unit tests for ImportBlock with the Pipeline class fully mocked out."""

    @patch("instructlab.sdg.pipeline.Pipeline")
    def setUp(self, mock_pipeline):
        # NOTE: patching setUp means the patch is active only while setUp
        # runs — but that is exactly when ImportBlock.__init__ calls
        # Pipeline.from_file, so the mock records the call, and the mock's
        # from_file return value is what ImportBlock chains generate() to.
        self.ctx = MagicMock()
        self.block_name = "test_block"
        self.path = "/path/to/config"
        self.mock_pipeline = mock_pipeline
        self.import_block = ImportBlock(self.ctx, self.block_name, self.path)
        # Empty dataset: these tests only verify call wiring, not filtering.
        self.dataset = Dataset.from_dict({})

    def test_initialization(self):
        # The constructor must store its identifiers and load the child
        # pipeline exactly once from the given path.
        self.assertEqual(self.import_block.block_name, self.block_name)
        self.assertEqual(self.import_block.path, self.path)
        self.mock_pipeline.from_file.assert_called_once_with(self.ctx, self.path)

    def test_generate(self):
        # generate() must delegate to the imported pipeline's generate()
        # and return its result unchanged.
        self.mock_pipeline.from_file.return_value.generate.return_value = self.dataset
        samples = self.import_block.generate(self.dataset)
        self.mock_pipeline.from_file.return_value.generate.assert_called_once_with(
            samples
        )
        self.assertEqual(samples, self.dataset)


_CHILD_YAML = """\
version: "1.0"
blocks:
- name: greater_than_thirty
type: FilterByValueBlock
config:
filter_column: age
filter_value: 30
operation: gt
convert_dtype: int
"""


_PARENT_YAML_FMT = """\
version: "1.0"
blocks:
- name: forty_or_under
type: FilterByValueBlock
config:
filter_column: age
filter_value: 40
operation: le
convert_dtype: int
- name: import_child
type: ImportBlock
config:
path: %s
- name: big_bdays
type: FilterByValueBlock
config:
filter_column: age
filter_value:
- 30
- 40
operation: eq
convert_dtype: int
"""


class TestImportBlockWithFilterByValue(unittest.TestCase):
def setUp(self):
self.ctx = MagicMock()
self.ctx.num_procs = 1
self.child_yaml = self._write_tmp_yaml(_CHILD_YAML)
self.parent_yaml = self._write_tmp_yaml(_PARENT_YAML_FMT % self.child_yaml)
self.dataset = Dataset.from_dict(
{"age": ["25", "30", "35", "40", "45"]},
features=Features({"age": Value("string")}),
)

def tearDown(self):
os.remove(self.parent_yaml)
os.remove(self.child_yaml)

def _write_tmp_yaml(self, content):
tmp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", suffix=".yaml")
tmp_file.write(content)
tmp_file.close()
return tmp_file.name

def test_generate(self):
pipeline = Pipeline.from_file(self.ctx, self.parent_yaml)
filtered_dataset = pipeline.generate(self.dataset)
self.assertEqual(len(filtered_dataset), 1)
self.assertEqual(filtered_dataset["age"], [40])

0 comments on commit 46f16c6

Please sign in to comment.