diff --git a/src/instructlab/sdg/filterblock.py b/src/instructlab/sdg/filterblock.py index 1090ce79..1ae3e5b9 100644 --- a/src/instructlab/sdg/filterblock.py +++ b/src/instructlab/sdg/filterblock.py @@ -178,9 +178,13 @@ def generate(self, samples) -> Dataset: samples, self.column_name, self.dtype, - self.ctx.num_procs, + self.ctx.dataset_num_procs, ) return _filter_by_values( - samples, self.column_name, self.operation, self.value, self.ctx.num_procs + samples, + self.column_name, + self.operation, + self.value, + self.ctx.dataset_num_procs, ) diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index d654dbfd..cfbb784e 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -173,7 +173,15 @@ def _check_pipeline_dir(pipeline): ) -def _sdg_init(pipeline, client, model_family, model_id, num_instructions_to_generate): +def _sdg_init( + pipeline: Pipeline, + client: openai.OpenAI, + model_family: str, + model_id: str, + num_instructions_to_generate: int, + batch_num_workers: Optional[int], + batch_size: Optional[int], +): pipeline_pkg = None # Search for the pipeline in User and Site data directories @@ -200,7 +208,18 @@ def _sdg_init(pipeline, client, model_family, model_id, num_instructions_to_gene ) _check_pipeline_dir(pipeline) - ctx = PipelineContext(client, model_family, model_id, num_instructions_to_generate) + extra_kwargs = {} + if batch_size is not None: + extra_kwargs["batch_size"] = batch_size + extra_kwargs["batch_num_workers"] = batch_num_workers + + ctx = PipelineContext( + client=client, + model_family=model_family, + model_id=model_id, + num_instructions_to_generate=num_instructions_to_generate, + **extra_kwargs, + ) def load_pipeline(yaml_basename): if pipeline_pkg: @@ -227,8 +246,6 @@ def generate_data( api_key: Optional[str] = None, model_family: Optional[str] = None, model_name: Optional[str] = None, - # TODO - not used -- when batching is enabled, this is relevant. - # Right now the code hard codes 8 cpus for batching num_cpus: Optional[int] = None, num_instructions_to_generate: Optional[int] = 30, taxonomy: Optional[str] = None, @@ -247,6 +264,7 @@ def generate_data( tls_client_key: Optional[str] = None, tls_client_passwd: Optional[str] = None, pipeline: Optional[str] = "simple", + batch_size: Optional[int] = None, ) -> None: """Generate data for training and testing a model. @@ -264,6 +282,10 @@ def generate_data( """ generate_start = time.time() + # FIXME: remove this when ilab knows to pass batch_size=0 with llama.cpp + if batch_size is None: + batch_size = 0 + if not os.path.exists(output_dir): os.mkdir(output_dir) @@ -302,15 +324,14 @@ def generate_data( else: model_family = MODEL_FAMILY_MERLINITE - # TODO -- llama-cpp doesn't support batching, we need to get a hint from the CLI - # about whether we can turn this on (whether vllm is used or not) - sdg_knowledge, sdg_freeform_skill, sdg_grounded_skill = _sdg_init( pipeline, client, model_family, model_name, num_instructions_to_generate, + batch_size=batch_size, + batch_num_workers=num_cpus, ) if console_output: diff --git a/src/instructlab/sdg/pipeline.py b/src/instructlab/sdg/pipeline.py index 8b006435..c08b5b94 100644 --- a/src/instructlab/sdg/pipeline.py +++ b/src/instructlab/sdg/pipeline.py @@ -1,11 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # Standard +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass from importlib import resources -from typing import Optional +from typing import Iterable, Optional +import math import os.path # Third Party -from datasets import Dataset +from datasets import Dataset, concatenate_datasets +from openai import OpenAI import yaml # Local @@ -22,16 +26,47 @@ class EmptyDatasetError(Exception): # This is part of the public API. -class PipelineContext: - def __init__( - self, client, model_family, model_id, num_instructions_to_generate - ) -> None: - self.client = client - self.model_family = model_family - self.model_id = model_id - self.num_instructions_to_generate = num_instructions_to_generate - # FIXME: base this on the available number of CPUs - self.num_procs = 8 +@dataclass +class PipelineContext: # pylint: disable=too-many-instance-attributes + """ + A PipelineContext holds the common attributes needed between blocks in a + pipeline + + client: The OpenAI client handle. + model_id: The ID of the teacher model to be used for client calls. + model_family: The family identifier for the model being updated. + num_instructions_to_generate: The total number of instructions the user + wants to generate during this run. + batch_size: The size of the dataset batches for parallel generation. Set to + 0 to disable batching. + batch_num_workers: The number of worker threads/processes to maintain in the + central executor pool. + dataset_num_procs: The number of processes to use when performing parallel + map operations on individual datasets. + """ + + # The default batch size of 8 has been determined as a good default for + # standard instructlab workloads when running with vllm batching. + DEFAULT_BATCH_SIZE = 8 + + # The default number of processes to use when performing parallel operations + # on individual datasets + DEFAULT_DATASET_NUM_PROCS = 8 + + client: OpenAI + model_family: str + model_id: str + num_instructions_to_generate: int + dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS + batch_size: int = DEFAULT_BATCH_SIZE + batch_num_workers: Optional[int] = None + + @property + def batching_enabled(self) -> bool: + """Batching is enabled IFF the batch size is specified and the number of + workers is not set explicitly to 1 + """ + return self.batch_size > 0 and self.batch_num_workers != 1 # This is part of the public API. @@ -63,7 +98,12 @@ def exception_message(self) -> str: # This is part of the public API. class Pipeline: - def __init__(self, ctx, config_path, chained_blocks: list) -> None: + def __init__( + self, + ctx: PipelineContext, + config_path: str, + chained_blocks: list[dict], + ) -> None: """ Initialize the Pipeline class with a configuration dictionary. config_dict: the run config py or yaml loaded into a dictionary @@ -81,20 +121,40 @@ def from_file(cls, ctx, pipeline_yaml): pipeline_yaml = os.path.join(resources.files(__package__), pipeline_yaml) return cls(ctx, pipeline_yaml, _parse_pipeline_config_file(pipeline_yaml)) - def _drop_duplicates(self, dataset, cols): - """ - Drop duplicates from the dataset based on the columns provided. - """ - df = dataset.to_pandas() - df = df.drop_duplicates(subset=cols).reset_index(drop=True) - ds = Dataset.from_pandas(df) - return ds - def generate(self, dataset) -> Dataset: """ Generate the dataset by running the pipeline steps. dataset: the input dataset """ + # If not batching, simply delegate to _generate_single + if not self.ctx.batching_enabled: + logger.info("Running pipeline single-threaded") + return self._generate_single(dataset) + + # Otherwise, split the dataset into batches and run each batch as a + # future in the thread pool + logger.info( + "Running pipeline with multi-threaded batching. Using %s workers for batches of size %s", + self.ctx.batch_num_workers, + self.ctx.batch_size, + ) + input_splits = self._split_dataset(dataset) + with ThreadPoolExecutor(max_workers=self.ctx.batch_num_workers) as executor: + futures = [ + executor.submit(self._generate_single, input_split) + for input_split in input_splits + ] + + # Collect the results of each batch as they finish. This needs to + # wait for them all, so the order of waiting doesn't matter + output_splits = [future.result() for future in futures] + + return concatenate_datasets(output_splits) + + ## Implementation Details ## + + def _generate_single(self, dataset) -> Dataset: + """Generate a single dataset by running the pipeline steps.""" for block_prop in self.chained_blocks: # Initialize arguments for error handling to None block, block_name, block_type = None, None, None @@ -134,6 +194,39 @@ def generate(self, dataset) -> Dataset: return dataset + def _drop_duplicates(self, dataset, cols): + """ + Drop duplicates from the dataset based on the columns provided. + """ + df = dataset.to_pandas() + df = df.drop_duplicates(subset=cols).reset_index(drop=True) + ds = Dataset.from_pandas(df) + return ds + + def _split_dataset(self, dataset: Dataset) -> list[Dataset]: + """Split the dataset into smaller batches.""" + assert ( + self.ctx.batch_size is not None + ), "Programming Error: Should not call _split_dataset if batching disabled" + total_size = len(dataset) + num_batches = math.ceil(total_size / self.ctx.batch_size) + batches = [ + dataset.select(self._get_batch_indices(i, total_size)) + for i in range(num_batches) + ] + return batches + + def _get_batch_indices(self, batch_index: int, total_size: int) -> Iterable[int]: + assert ( + self.ctx.batch_size is not None + ), "Programming Error: Should not call _get_batch_indices if batching disabled" + return range( + # Start index offset by the batch size + batch_index * self.ctx.batch_size, + # End index is the next batch offset or the end of the dataset + min((batch_index + 1) * self.ctx.batch_size, total_size), + ) + _block_types = { "CombineColumnsBlock": utilblocks.CombineColumnsBlock, diff --git a/src/instructlab/sdg/utilblocks.py b/src/instructlab/sdg/utilblocks.py index 30ac90f2..f0e93cf1 100644 --- a/src/instructlab/sdg/utilblocks.py +++ b/src/instructlab/sdg/utilblocks.py @@ -35,7 +35,7 @@ def populate(sample): def generate(self, samples) -> Dataset: return self._map_populate( - samples, self.configs, self.column_name, self.ctx.num_procs + samples, self.configs, self.column_name, self.ctx.dataset_num_procs ) @@ -64,7 +64,7 @@ def generate(self, samples: Dataset) -> Dataset: self.choice_map, self.choice_col, self.output_col, - self.ctx.num_procs, + self.ctx.dataset_num_procs, ) @@ -89,5 +89,9 @@ def combine(sample): def generate(self, samples: Dataset) -> Dataset: return self._map_combine( - samples, self.columns, self.output_col, self.separator, self.ctx.num_procs + samples, + self.columns, + self.output_col, + self.separator, + self.ctx.dataset_num_procs, ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..cb0c308a --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,48 @@ +""" +Common fixtures and testing utilities +""" + +# Standard +from unittest import mock + +# Third Party +from datasets import Dataset +import pytest + +# First Party +from instructlab.sdg.pipeline import PipelineContext + + +def get_ctx(**kwargs) -> PipelineContext: + kwargs.setdefault("client", mock.MagicMock()) + kwargs.setdefault("model_family", "test") + kwargs.setdefault("model_id", "test-model") + kwargs.setdefault("num_instructions_to_generate", 10) + kwargs.setdefault("dataset_num_procs", 1) + return PipelineContext(**kwargs) + + +def get_single_threaded_ctx(**kwargs) -> PipelineContext: + kwargs["batch_size"] = 0 + return get_ctx(**kwargs) + + +def get_threaded_ctx(**kwargs) -> PipelineContext: + kwargs["batch_size"] = 6 + kwargs["batch_num_workers"] = 2 + return get_ctx(**kwargs) + + +@pytest.fixture +def single_threaded_ctx() -> PipelineContext: + return get_single_threaded_ctx() + + +@pytest.fixture +def threaded_ctx() -> PipelineContext: + return get_threaded_ctx() + + +@pytest.fixture +def sample_dataset(): + return Dataset.from_list([{"foo": i} for i in range(10)]) diff --git a/tests/test_filterblock.py b/tests/test_filterblock.py index c4d0b1f8..8cba6b3c 100644 --- a/tests/test_filterblock.py +++ b/tests/test_filterblock.py @@ -14,7 +14,7 @@ class TestFilterByValueBlock(unittest.TestCase): def setUp(self): self.ctx = MagicMock() - self.ctx.num_procs = 1 + self.ctx.dataset_num_procs = 1 self.pipe = MagicMock() self.block = FilterByValueBlock( self.ctx, diff --git a/tests/test_generate_data.py b/tests/test_generate_data.py new file mode 100644 index 00000000..9385ba06 --- /dev/null +++ b/tests/test_generate_data.py @@ -0,0 +1,44 @@ +""" +Unit tests for the top-level generate_data module. +""" + +# Standard +from unittest import mock + +# First Party +from instructlab.sdg.generate_data import _sdg_init +from instructlab.sdg.pipeline import PipelineContext + + +def test_sdg_init_batch_size_optional(): + """Test that the _sdg_init function can handle a missing batch size by + delegating to the default in PipelineContext. + """ + sdgs = _sdg_init( + "simple", + None, + "mixtral", + "foo.bar", + 1, + batch_size=None, + batch_num_workers=None, + ) + assert all( + pipe.ctx.batch_size == PipelineContext.DEFAULT_BATCH_SIZE + for sdg in sdgs + for pipe in sdg.pipelines + ) + + +def test_sdg_init_batch_size_optional(): + """Test that the _sdg_init function can handle a passed batch size""" + sdgs = _sdg_init( + "simple", + None, + "mixtral", + "foo.bar", + 1, + batch_size=20, + batch_num_workers=32, + ) + assert all(pipe.ctx.batch_size == 20 for sdg in sdgs for pipe in sdg.pipelines) diff --git a/tests/test_importblock.py b/tests/test_importblock.py index 3cd38c29..b454b327 100644 --- a/tests/test_importblock.py +++ b/tests/test_importblock.py @@ -11,11 +11,14 @@ from instructlab.sdg.importblock import ImportBlock from instructlab.sdg.pipeline import Pipeline +# Local +from .conftest import get_single_threaded_ctx + class TestImportBlockWithMockPipeline(unittest.TestCase): @patch("instructlab.sdg.pipeline.Pipeline") def setUp(self, mock_pipeline): - self.ctx = MagicMock() + self.ctx = get_single_threaded_ctx() self.pipe = MagicMock() self.block_name = "test_block" self.path = "/path/to/config" @@ -79,8 +82,7 @@ def test_generate(self): class TestImportBlockWithFilterByValue(unittest.TestCase): def setUp(self): - self.ctx = MagicMock() - self.ctx.num_procs = 1 + self.ctx = get_single_threaded_ctx() self.child_yaml = self._write_tmp_yaml(_CHILD_YAML) self.parent_yaml = self._write_tmp_yaml(_PARENT_YAML_FMT % self.child_yaml) self.dataset = Dataset.from_dict( diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 443eda26..bcbe454e 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -3,6 +3,8 @@ """ # Standard +from contextlib import contextmanager +from threading import Event from unittest import mock # Third Party @@ -13,8 +15,94 @@ from instructlab.sdg.block import Block from instructlab.sdg.pipeline import Pipeline, PipelineBlockError +## Helpers ## -def test_pipeline_named_errors_match_type(): + +@contextmanager +def block_types(block_types_dict): + with mock.patch( + "instructlab.sdg.pipeline._block_types", + block_types_dict, + ): + yield + + +## Pipeline Batching ## + + +def test_pipeline_no_batching(sample_dataset, single_threaded_ctx): + """Test that with no batching enabled, the block is called once""" + block_type_mock = mock.MagicMock() + block_type_mock().generate.return_value = sample_dataset + pipe_cfg = [ + { + "name": "block-one", + "type": "test", + "config": {}, + } + ] + with block_types({"test": block_type_mock}): + Pipeline(single_threaded_ctx, "", pipe_cfg).generate(sample_dataset) + block_type_mock().generate.assert_called_once_with(sample_dataset) + + +def test_pipeline_with_batching(sample_dataset, threaded_ctx): + """Test that when configured with batching enabled, the block is called + multiple times, once for each batch + """ + block_type_mock = mock.MagicMock() + block_type_mock().generate.return_value = sample_dataset + pipe_cfg = [ + { + "name": "block-one", + "type": "test", + "config": {}, + } + ] + with block_types({"test": block_type_mock}): + Pipeline(threaded_ctx, "", pipe_cfg).generate(sample_dataset) + block_type_mock().generate.call_count > 1 + + +def test_pipeline_batching_order_correct(sample_dataset, threaded_ctx): + """Make sure that batches are recombined in the correct order""" + + class MockBlockType: + # NOTE: This needs to be a class variable because it will be different + # instances of the block for each batch + _second_half_event = Event() + + def __init__(self, *_, **__): + pass + + def generate(self, dataset): + # Make sure the second half is processed before the first half + if dataset[0]["foo"] == 0: + print("A") + self._second_half_event.wait() + print("B") + else: + print("C") + self._second_half_event.set() + print("D") + return dataset.map(lambda r: {"foo": r["foo"] * 2}) + + pipe_cfg = [ + { + "name": "block-one", + "type": "test", + "config": {}, + } + ] + with block_types({"test": MockBlockType}): + res = Pipeline(threaded_ctx, "", pipe_cfg).generate(sample_dataset) + assert res.to_list() == [{"foo": i * 2} for i in range(10)] + + +## Pipeline Error Handling ## + + +def test_pipeline_named_errors_match_type(single_threaded_ctx): """Validate that a PipelineBlockError is raised to wrap exceptions raised in a Block's generate method """ @@ -29,14 +117,13 @@ def test_pipeline_named_errors_match_type(): {"name": "I work", "type": "working", "config": {}}, {"name": "I don't", "type": "failure", "config": {}}, ] - with mock.patch( - "instructlab.sdg.pipeline._block_types", + with block_types( { "working": working_block, "failure": failure_block, }, ): - pipe = Pipeline(None, None, pipe_cfg) + pipe = Pipeline(single_threaded_ctx, "", pipe_cfg) with pytest.raises(PipelineBlockError) as exc_ctx: pipe.generate(None) @@ -45,7 +132,7 @@ def test_pipeline_named_errors_match_type(): assert exc_ctx.value.block is failure_block() -def test_pipeline_config_error_handling(): +def test_pipeline_config_error_handling(single_threaded_ctx): """Validate that a PipelineBlockError is raised when block config is incorrect """ @@ -53,13 +140,16 @@ def test_pipeline_config_error_handling(): {"name_not_there": "I work", "type": "working", "config": {}}, {"name": "I don't", "type": "failure", "config": {}}, ] - pipe = Pipeline(None, None, pipe_cfg) + pipe = Pipeline(single_threaded_ctx, "", pipe_cfg) with pytest.raises(PipelineBlockError) as exc_ctx: pipe.generate(None) assert isinstance(exc_ctx.value.__cause__, KeyError) +## PipelineBlockError ## + + def test_block_generation_error_properties_from_block(): """Make sure the PipelineBlockError exposes its properties and string form correctly when pulled from a Block instance diff --git a/tests/test_sample_populator_block.py b/tests/test_sample_populator_block.py index 409adf9d..c78d995d 100644 --- a/tests/test_sample_populator_block.py +++ b/tests/test_sample_populator_block.py @@ -12,7 +12,7 @@ class TestFilterByValueBlock(unittest.TestCase): def setUp(self): self.ctx = MagicMock() - self.ctx.num_procs = 1 + self.ctx.dataset_num_procs = 1 self.pipe = MagicMock() @patch("instructlab.sdg.block.Block._load_config")