Merge pull request instructlab#157 from gabe-l-hart/LLMBlockConcurrency-135

LLMBlock concurrency
derekhiggins authored Jul 19, 2024
2 parents e4765b9 + a27a1b8 commit 30fb578
Showing 10 changed files with 351 additions and 45 deletions.
8 changes: 6 additions & 2 deletions src/instructlab/sdg/filterblock.py
@@ -178,9 +178,13 @@ def generate(self, samples) -> Dataset:
samples,
self.column_name,
self.dtype,
self.ctx.num_procs,
self.ctx.dataset_num_procs,
)

return _filter_by_values(
samples, self.column_name, self.operation, self.value, self.ctx.num_procs
samples,
self.column_name,
self.operation,
self.value,
self.ctx.dataset_num_procs,
)
35 changes: 28 additions & 7 deletions src/instructlab/sdg/generate_data.py
@@ -173,7 +173,15 @@ def _check_pipeline_dir(pipeline):
)


def _sdg_init(pipeline, client, model_family, model_id, num_instructions_to_generate):
def _sdg_init(
pipeline: Pipeline,
client: openai.OpenAI,
model_family: str,
model_id: str,
num_instructions_to_generate: int,
batch_num_workers: Optional[int],
batch_size: Optional[int],
):
pipeline_pkg = None

# Search for the pipeline in User and Site data directories
@@ -200,7 +208,18 @@ def _sdg_init(pipeline, client, model_family, model_id, num_instructions_to_generate):
)
_check_pipeline_dir(pipeline)

ctx = PipelineContext(client, model_family, model_id, num_instructions_to_generate)
extra_kwargs = {}
if batch_size is not None:
extra_kwargs["batch_size"] = batch_size
extra_kwargs["batch_num_workers"] = batch_num_workers

ctx = PipelineContext(
client=client,
model_family=model_family,
model_id=model_id,
num_instructions_to_generate=num_instructions_to_generate,
**extra_kwargs,
)

def load_pipeline(yaml_basename):
if pipeline_pkg:
@@ -227,8 +246,6 @@ def generate_data(
api_key: Optional[str] = None,
model_family: Optional[str] = None,
model_name: Optional[str] = None,
# TODO - not used -- when batching is enabled, this is relevant.
# Right now the code hard codes 8 cpus for batching
num_cpus: Optional[int] = None,
num_instructions_to_generate: Optional[int] = 30,
taxonomy: Optional[str] = None,
@@ -247,6 +264,7 @@
tls_client_key: Optional[str] = None,
tls_client_passwd: Optional[str] = None,
pipeline: Optional[str] = "simple",
batch_size: Optional[int] = None,
) -> None:
"""Generate data for training and testing a model.
@@ -264,6 +282,10 @@
"""
generate_start = time.time()

# FIXME: remove this when ilab knows to pass batch_size=0 with llama.cpp
if batch_size is None:
batch_size = 0

if not os.path.exists(output_dir):
os.mkdir(output_dir)

@@ -302,15 +324,14 @@
else:
model_family = MODEL_FAMILY_MERLINITE

# TODO -- llama-cpp doesn't support batching, we need to get a hint from the CLI
# about whether we can turn this on (whether vllm is used or not)

sdg_knowledge, sdg_freeform_skill, sdg_grounded_skill = _sdg_init(
pipeline,
client,
model_family,
model_name,
num_instructions_to_generate,
batch_size=batch_size,
batch_num_workers=num_cpus,
)

if console_output:
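For reference, the keyword plumbing added to `_sdg_init` reduces to the sketch below (`build_batching_kwargs` is a hypothetical helper, not part of this commit): when `batch_size` is `None` both batching kwargs are omitted so `PipelineContext` falls back to its class defaults; any explicit value, including 0, is forwarded together with `batch_num_workers` (wired to `num_cpus` by `generate_data`).

```python
# Sketch only: mirrors the extra_kwargs handling added to _sdg_init above.
from typing import Optional


def build_batching_kwargs(
    batch_size: Optional[int], batch_num_workers: Optional[int]
) -> dict:
    extra_kwargs = {}
    if batch_size is not None:
        # A batch_size of 0 disables batching; leaving it as None lets
        # PipelineContext fall back to its DEFAULT_BATCH_SIZE of 8.
        extra_kwargs["batch_size"] = batch_size
        extra_kwargs["batch_num_workers"] = batch_num_workers
    return extra_kwargs


print(build_batching_kwargs(None, 4))  # {} -> PipelineContext defaults apply
print(build_batching_kwargs(0, None))  # {'batch_size': 0, 'batch_num_workers': None}
```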
137 changes: 115 additions & 22 deletions src/instructlab/sdg/pipeline.py
@@ -1,11 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# Standard
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from importlib import resources
from typing import Optional
from typing import Iterable, Optional
import math
import os.path

# Third Party
from datasets import Dataset
from datasets import Dataset, concatenate_datasets
from openai import OpenAI
import yaml

# Local
@@ -22,16 +26,47 @@ class EmptyDatasetError(Exception):


# This is part of the public API.
class PipelineContext:
def __init__(
self, client, model_family, model_id, num_instructions_to_generate
) -> None:
self.client = client
self.model_family = model_family
self.model_id = model_id
self.num_instructions_to_generate = num_instructions_to_generate
# FIXME: base this on the available number of CPUs
self.num_procs = 8
@dataclass
class PipelineContext: # pylint: disable=too-many-instance-attributes
"""
A PipelineContext holds the common attributes needed between blocks in a
pipeline
client: The OpenAI client handle.
model_id: The ID of the teacher model to be used for client calls.
model_family: The family identifier for the model being updated.
num_instructions_to_generate: The total number of instructions the user
wants to generate during this run.
batch_size: The size of the dataset batches for parallel generation. Set to
0 to disable batching.
batch_num_workers: The number of worker threads/processes to maintain in the
central executor pool.
dataset_num_procs: The number of processes to use when performing parallel
map operations on individual datasets.
"""

# The default batch size of 8 has been determined as a good default for
# standard instructlab workloads when running with vllm batching.
DEFAULT_BATCH_SIZE = 8

# The default number of processes to use when performing parallel operations
# on individual datasets
DEFAULT_DATASET_NUM_PROCS = 8

client: OpenAI
model_family: str
model_id: str
num_instructions_to_generate: int
dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS
batch_size: int = DEFAULT_BATCH_SIZE
batch_num_workers: Optional[int] = None

@property
def batching_enabled(self) -> bool:
"""Batching is enabled IFF the batch size is specified and the number of
workers is not set explicitly to 1
"""
return self.batch_size > 0 and self.batch_num_workers != 1


# This is part of the public API.
@@ -63,7 +98,12 @@ def exception_message(self) -> str:

# This is part of the public API.
class Pipeline:
def __init__(self, ctx, config_path, chained_blocks: list) -> None:
def __init__(
self,
ctx: PipelineContext,
config_path: str,
chained_blocks: list[dict],
) -> None:
"""
Initialize the Pipeline class with a configuration dictionary.
config_dict: the run config py or yaml loaded into a dictionary
@@ -81,20 +121,40 @@ def from_file(cls, ctx, pipeline_yaml):
pipeline_yaml = os.path.join(resources.files(__package__), pipeline_yaml)
return cls(ctx, pipeline_yaml, _parse_pipeline_config_file(pipeline_yaml))

def _drop_duplicates(self, dataset, cols):
"""
Drop duplicates from the dataset based on the columns provided.
"""
df = dataset.to_pandas()
df = df.drop_duplicates(subset=cols).reset_index(drop=True)
ds = Dataset.from_pandas(df)
return ds

def generate(self, dataset) -> Dataset:
"""
Generate the dataset by running the pipeline steps.
dataset: the input dataset
"""
# If not batching, simply delegate to _generate_single
if not self.ctx.batching_enabled:
logger.info("Running pipeline single-threaded")
return self._generate_single(dataset)

# Otherwise, split the dataset into batches and run each batch as a
# future in the thread pool
logger.info(
"Running pipeline with multi-threaded batching. Using %s workers for batches of size %s",
self.ctx.batch_num_workers,
self.ctx.batch_size,
)
input_splits = self._split_dataset(dataset)
with ThreadPoolExecutor(max_workers=self.ctx.batch_num_workers) as executor:
futures = [
executor.submit(self._generate_single, input_split)
for input_split in input_splits
]

# Collect the results of each batch as they finish. This needs to
# wait for them all, so the order of waiting doesn't matter
output_splits = [future.result() for future in futures]

return concatenate_datasets(output_splits)

## Implementation Details ##

def _generate_single(self, dataset) -> Dataset:
"""Generate a single dataset by running the pipeline steps."""
for block_prop in self.chained_blocks:
# Initialize arguments for error handling to None
block, block_name, block_type = None, None, None
@@ -134,6 +194,39 @@ def generate(self, dataset) -> Dataset:

return dataset

def _drop_duplicates(self, dataset, cols):
"""
Drop duplicates from the dataset based on the columns provided.
"""
df = dataset.to_pandas()
df = df.drop_duplicates(subset=cols).reset_index(drop=True)
ds = Dataset.from_pandas(df)
return ds

def _split_dataset(self, dataset: Dataset) -> list[Dataset]:
"""Split the dataset into smaller batches."""
assert (
self.ctx.batch_size is not None
), "Programming Error: Should not call _split_dataset if batching disabled"
total_size = len(dataset)
num_batches = math.ceil(total_size / self.ctx.batch_size)
batches = [
dataset.select(self._get_batch_indices(i, total_size))
for i in range(num_batches)
]
return batches

def _get_batch_indices(self, batch_index: int, total_size: int) -> Iterable[int]:
assert (
self.ctx.batch_size is not None
), "Programming Error: Should not call _get_batch_indices if batching disabled"
return range(
# Start index offset by the batch size
batch_index * self.ctx.batch_size,
# End index is the next batch offset or the end of the dataset
min((batch_index + 1) * self.ctx.batch_size, total_size),
)


_block_types = {
"CombineColumnsBlock": utilblocks.CombineColumnsBlock,
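A minimal usage sketch of the batching controls introduced above (illustrative only; the mock client and model identifiers are stand-ins, not values from this commit):

```python
# Illustrative sketch of the batching behaviour added to PipelineContext;
# the mock client stands in for a real openai.OpenAI instance.
import math
from unittest import mock

from instructlab.sdg.pipeline import PipelineContext

# With the defaults, batching is enabled (batch_size=8, pool-sized workers).
ctx = PipelineContext(
    client=mock.MagicMock(),
    model_family="merlinite",
    model_id="teacher-model",
    num_instructions_to_generate=30,
)
assert ctx.batching_enabled  # batch_size > 0 and batch_num_workers != 1

# batch_size=0 (what generate_data now falls back to for llama.cpp) disables
# batching, so Pipeline.generate() delegates straight to _generate_single.
serial_ctx = PipelineContext(
    client=mock.MagicMock(),
    model_family="merlinite",
    model_id="teacher-model",
    num_instructions_to_generate=30,
    batch_size=0,
)
assert not serial_ctx.batching_enabled

# The batch index math from _split_dataset/_get_batch_indices: a 20-row
# dataset with batch_size=8 is split into ranges [0, 8), [8, 16), [16, 20).
total = 20
splits = [
    range(i * ctx.batch_size, min((i + 1) * ctx.batch_size, total))
    for i in range(math.ceil(total / ctx.batch_size))
]
assert [len(s) for s in splits] == [8, 8, 4]
```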
10 changes: 7 additions & 3 deletions src/instructlab/sdg/utilblocks.py
@@ -35,7 +35,7 @@ def populate(sample):

def generate(self, samples) -> Dataset:
return self._map_populate(
samples, self.configs, self.column_name, self.ctx.num_procs
samples, self.configs, self.column_name, self.ctx.dataset_num_procs
)


@@ -64,7 +64,7 @@ def generate(self, samples: Dataset) -> Dataset:
self.choice_map,
self.choice_col,
self.output_col,
self.ctx.num_procs,
self.ctx.dataset_num_procs,
)


@@ -89,5 +89,9 @@ def combine(sample):

def generate(self, samples: Dataset) -> Dataset:
return self._map_combine(
samples, self.columns, self.output_col, self.separator, self.ctx.num_procs
samples,
self.columns,
self.output_col,
self.separator,
self.ctx.dataset_num_procs,
)
48 changes: 48 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,48 @@
"""
Common fixtures and testing utilities
"""

# Standard
from unittest import mock

# Third Party
from datasets import Dataset
import pytest

# First Party
from instructlab.sdg.pipeline import PipelineContext


def get_ctx(**kwargs) -> PipelineContext:
kwargs.setdefault("client", mock.MagicMock())
kwargs.setdefault("model_family", "test")
kwargs.setdefault("model_id", "test-model")
kwargs.setdefault("num_instructions_to_generate", 10)
kwargs.setdefault("dataset_num_procs", 1)
return PipelineContext(**kwargs)


def get_single_threaded_ctx(**kwargs) -> PipelineContext:
kwargs["batch_size"] = 0
return get_ctx(**kwargs)


def get_threaded_ctx(**kwargs) -> PipelineContext:
kwargs["batch_size"] = 6
kwargs["batch_num_workers"] = 2
return get_ctx(**kwargs)


@pytest.fixture
def single_threaded_ctx() -> PipelineContext:
return get_single_threaded_ctx()


@pytest.fixture
def threaded_ctx() -> PipelineContext:
return get_threaded_ctx()


@pytest.fixture
def sample_dataset():
return Dataset.from_list([{"foo": i} for i in range(10)])
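A hypothetical test built on these fixtures could look like the sketch below (illustrative, not part of this commit): an empty pipeline should return the input rows unchanged whether or not batching is enabled.

```python
# Hypothetical usage of the new conftest fixtures; not part of this commit.
from instructlab.sdg.pipeline import Pipeline


def test_empty_pipeline_passthrough(threaded_ctx, sample_dataset):
    # With no blocks, each batch passes through _generate_single untouched,
    # so the concatenated output has the same number of rows as the input.
    pipe = Pipeline(threaded_ctx, config_path="", chained_blocks=[])
    output = pipe.generate(sample_dataset)
    assert len(output) == len(sample_dataset)
```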
2 changes: 1 addition & 1 deletion tests/test_filterblock.py
@@ -14,7 +14,7 @@
class TestFilterByValueBlock(unittest.TestCase):
def setUp(self):
self.ctx = MagicMock()
self.ctx.num_procs = 1
self.ctx.dataset_num_procs = 1
self.pipe = MagicMock()
self.block = FilterByValueBlock(
self.ctx,