From c8145f84b194a7dbc2a937d394762b21f996bd21 Mon Sep 17 00:00:00 2001
From: Ben Browning
Date: Tue, 28 Jan 2025 14:06:10 -0500
Subject: [PATCH] Make everything but client optional on PipelineContext

This makes model_family, model_id, and num_instructions_to_generate
optional on PipelineContext instead of required. This should be a
backwards-compatible change, and includes tests that verify that passing
the previously required parameters into PipelineContext still works as
expected.

This is a move towards supporting and expecting Pipelines whose LLMBlocks
may use different model families or different model ids. If a Block's
yaml config specifies a model_family or model_id and none is set on the
PipelineContext, then the Block's values will get used. However, if these
are set on the PipelineContext, they're treated as overrides and take
precedence for all Blocks in this Pipeline.

This emulates the previous behavior, but we should steer users away from
ever setting model_id or model_family on a PipelineContext if the
Pipeline contains multiple LLMBlock entries.

Fixes #511

Signed-off-by: Ben Browning
---
 CHANGELOG.md                                  |  12 ++
 docs/examples/multiple_llms/README.md         |   5 +
 docs/examples/multiple_llms/pipeline.yaml     |  34 +++
 scripts/validate_pipelines.py                 |   1 +
 src/instructlab/sdg/blocks/llmblock.py        |  57 ++++-
 src/instructlab/sdg/pipeline.py               |  11 +-
 src/instructlab/sdg/pipelines/schema/v1.json  |  12 ++
 tests/functional/test_granular_api.py         |   1 +
 tests/test_llmblock.py                        | 210 +++++++++++++++++++
 tests/test_pipeline.py                        |  89 +++++++-
 tests/testdata/custom_block.py                |   2 +-
 11 files changed, 421 insertions(+), 13 deletions(-)
 create mode 100644 docs/examples/multiple_llms/README.md
 create mode 100644 docs/examples/multiple_llms/pipeline.yaml

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c513bc63..1518ed52 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,15 @@
+## Upcoming v0.8.x
+
+### Features
+
+#### LLMBlocks can now specify `model_family` or `model_id` in their config
+
+Each `LLMBlock` in a `Pipeline` can now specify `model_family` or `model_id` in its yaml configuration to set the values used by that block, as opposed to setting them for the entire `Pipeline` in the `PipelineContext` object. This is useful when multiple `LLMBlock` entries in the same `Pipeline` each use a different model.
+
+#### Fewer required parameters for `PipelineContext`
+
+The parameters `model_family`, `model_id`, and `num_instructions_to_generate` are no longer required in `PipelineContext` objects. They used to be required, and if passed in they will still be used as before. However, they can now be omitted if your `Pipeline` contains no `LLMBlock` entries or if your `LLMBlock` config specifies these values in the `Pipeline` yaml.
+
 ## v0.7.0
 
 ### Features
diff --git a/docs/examples/multiple_llms/README.md b/docs/examples/multiple_llms/README.md
new file mode 100644
index 00000000..19a55cd7
--- /dev/null
+++ b/docs/examples/multiple_llms/README.md
@@ -0,0 +1,5 @@
+# Multiple LLMs in a single Pipeline
+
+To create a `Pipeline` that uses multiple different LLMs for inference, each `LLMBlock` in the `Pipeline` yaml should specify a `model_family` (to indicate which set of prompts to use) and a `model_id` (corresponding to the id of this model in the deployed inference server). For more control over the prompting, a direct `model_prompt` may be used in place of `model_family` to specify the exact prompt to use, as opposed to choosing a generic one for that model family.
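A minimal sketch of how such a pipeline might be driven, assuming an OpenAI-compatible server at a placeholder URL and a toy seed dataset; only `client` is set on the `PipelineContext`, so each `LLMBlock`'s yaml config chooses its own model:

```python
# Illustrative only: base_url, api_key, and the seed dataset are placeholders.
from datasets import Dataset
from openai import OpenAI

from instructlab.sdg import Pipeline, PipelineContext

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# No model_family or model_id here -- each LLMBlock's yaml config supplies its own
ctx = PipelineContext(client=client)

pipeline = Pipeline.from_file(ctx, "docs/examples/multiple_llms/pipeline.yaml")
seed_data = Dataset.from_list([{"document": "example document", "domain": "example"}])
generated = pipeline.generate(seed_data)
```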
+
+See `pipeline.yaml` in this directory for an example of a Pipeline that calls into 3 separate LLMs.
diff --git a/docs/examples/multiple_llms/pipeline.yaml b/docs/examples/multiple_llms/pipeline.yaml
new file mode 100644
index 00000000..1d2af75c
--- /dev/null
+++ b/docs/examples/multiple_llms/pipeline.yaml
@@ -0,0 +1,34 @@
+version: "1.0"
+blocks:
+  - name: model_one
+    type: LLMBlock
+    config:
+      model_family: mixtral
+      model_id: Mixtral-8x7B-Instruct-v0.1
+      config_path: model_one_config.yaml
+      output_cols:
+        - column_one
+      gen_kwargs:
+        max_tokens: 2048
+
+  - name: model_two
+    type: LLMBlock
+    config:
+      model_family: granite
+      model_id: granite-7b-lab
+      config_path: model_two_config.yaml
+      output_cols:
+        - column_two
+      gen_kwargs:
+        max_tokens: 512
+
+  - name: model_three
+    type: LLMBlock
+    config:
+      model_id: granite-7b-lab
+      model_prompt: "[INST] {prompt} [/INST]"
+      config_path: model_three_config.yaml
+      output_cols:
+        - column_three
+      gen_kwargs:
+        max_tokens: 5
diff --git a/scripts/validate_pipelines.py b/scripts/validate_pipelines.py
index 4cb6467f..fcfcf41c 100755
--- a/scripts/validate_pipelines.py
+++ b/scripts/validate_pipelines.py
@@ -30,6 +30,7 @@ def main():
         schema = json.load(file)
 
     yaml_files = glob.glob("src/instructlab/sdg/pipelines/**/*.yaml", recursive=True)
+    yaml_files.extend(glob.glob("docs/examples/**/pipeline.yaml", recursive=True))
     all_valid = True
     for yaml_file in yaml_files:
         print("=======================================================")
diff --git a/src/instructlab/sdg/blocks/llmblock.py b/src/instructlab/sdg/blocks/llmblock.py
index 3350a771..de3b2fdf 100644
--- a/src/instructlab/sdg/blocks/llmblock.py
+++ b/src/instructlab/sdg/blocks/llmblock.py
@@ -15,6 +15,7 @@
 # Import prompts to register default chat templates
 from .. import prompts as default_prompts  # pylint: disable=unused-import
 from ..registry import BlockRegistry, PromptRegistry
+from ..utils import models
 from .block import Block, BlockConfigParserError
 
 logger = logging.getLogger(__name__)
@@ -62,6 +63,33 @@ def template_from_struct_and_config(struct, config):
     return PromptRegistry.template_from_string(struct.format(**filtered_config))
 
 
+def _resolve_model_id(model_id, ctx_model_id, block):
+    # If a model id was passed in the PipelineContext, use that
+    if ctx_model_id:
+        return ctx_model_id
+
+    # If we have no model id at all, raise an error
+    if not model_id:
+        raise BlockConfigParserError(
+            f"{type(block).__name__} {block.block_name} requires a model_id but none was specified in the block config nor passed via the PipelineContext"
+        )
+
+    # Otherwise fallback to the model_id specified in the block config
+    return model_id
+
+
+def _resolve_model_family(model_family, ctx_model_family):
+    # If a model family was passed in the PipelineContext, use that
+    if ctx_model_family:
+        return ctx_model_family
+
+    # Otherwise fallback to the model_family specified in the block config
+    # This could be None, but that's ok for model_family because when the
+    # actual prompt gets selected it falls back to merlinite if nothing
+    # else was given.
+    return model_family
+
+
 # This is part of the public API.
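A minimal sketch of the precedence these helpers implement, assuming they keep the names and signatures above; the model ids and families are arbitrary examples, and the PipelineContext value always wins over the block's yaml config:

```python
from unittest.mock import MagicMock

from instructlab.sdg.blocks.llmblock import _resolve_model_family, _resolve_model_id

block = MagicMock()  # only type() and block_name are read, and only on the error path

# Block yaml value is used when the PipelineContext sets nothing
assert _resolve_model_id("granite-7b-lab", None, block) == "granite-7b-lab"
# A PipelineContext value overrides the block yaml
assert _resolve_model_id("granite-7b-lab", "test_model", block) == "test_model"

assert _resolve_model_family("granite", None) == "granite"
assert _resolve_model_family("granite", "mixtral") == "mixtral"
# model_family may stay None; prompt selection falls back to a default family later
assert _resolve_model_family(None, None) is None
```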
@BlockRegistry.register("LLMBlock") # pylint: disable=dangerous-default-value @@ -74,6 +102,8 @@ def __init__( block_name, config_path, output_cols, + model_id=None, + model_family=None, model_prompt=None, gen_kwargs={}, parser_kwargs={}, @@ -87,6 +117,11 @@ def __init__( self.prompt_template = template_from_struct_and_config( self.prompt_struct, self.block_config ) + self.model_id = _resolve_model_id(model_id, self.ctx.model_id, self) + self.model_family = models.get_model_family( + _resolve_model_family(model_family, self.ctx.model_family), + self.model_id, + ) self.model_prompt = model_prompt self.output_cols = output_cols self.batch_params = batch_kwargs @@ -104,14 +139,14 @@ def __init__( self.gen_kwargs = self._gen_kwargs( max_num_token_override, gen_kwargs, - model=self.ctx.model_id, + model=self.model_id, temperature=0, max_tokens=DEFAULT_MAX_NUM_TOKENS, ) # Whether the LLM server supports a list of input prompts # and supports the n parameter to generate n outputs per input self.server_supports_batched = server_supports_batched( - self.ctx.client, self.ctx.model_id + self.ctx.client, self.model_id ) def _parse(self, generated_string) -> dict: @@ -160,7 +195,7 @@ def _format_prompt(self, sample: Dict) -> str: model_prompt = None if self.model_prompt is None: - model_prompt = PromptRegistry.get_template(self.ctx.model_family) + model_prompt = PromptRegistry.get_template(self.model_family) elif self.model_prompt: model_prompt = PromptRegistry.template_from_string(self.model_prompt) else: @@ -183,6 +218,10 @@ def _gen_kwargs(self, max_num_token_override, gen_kwargs, **defaults): and isinstance(gen_kwargs["n"], str) and gen_kwargs["n"] == "scaled" ): + if not self.ctx.num_instructions_to_generate: + raise BlockConfigParserError( + f"""LLMBlock {self.block_name} has a gen_kwargs["n"] value of "scaled" but num_instructions_to_generate was not set in the PipelineContext""" + ) gen_kwargs["n"] = self.ctx.num_instructions_to_generate if "temperature" in gen_kwargs: gen_kwargs["temperature"] = float(gen_kwargs["temperature"]) @@ -285,6 +324,8 @@ def __init__( config_paths, output_cols, selector_column_name, + model_id=None, + model_family=None, model_prompt=None, gen_kwargs={}, parser_kwargs={}, @@ -305,6 +346,8 @@ def __init__( block_name, config_paths[0][0], output_cols, + model_id=model_id, + model_family=model_family, model_prompt=model_prompt, gen_kwargs=gen_kwargs, parser_kwargs=parser_kwargs, @@ -360,6 +403,8 @@ def __init__( block_name, config_path, output_cols, + model_id=None, + model_family=None, model_prompt=None, gen_kwargs={}, parser_kwargs={}, @@ -371,6 +416,8 @@ def __init__( block_name, config_path, output_cols, + model_id=model_id, + model_family=model_family, model_prompt=model_prompt, gen_kwargs=gen_kwargs, parser_kwargs=parser_kwargs, @@ -466,14 +513,16 @@ def __init__( block_name, input_col, output_col, + model_id=None, gen_kwargs={}, ) -> None: super().__init__(ctx, pipe, block_name) + self.model_id = _resolve_model_id(model_id, self.ctx.model_id, self) self.input_col = input_col self.output_col = output_col self.gen_kwargs = self._gen_kwargs( gen_kwargs, - model=self.ctx.model_id, + model=self.model_id, temperature=0, max_tokens=DEFAULT_MAX_NUM_TOKENS, ) diff --git a/src/instructlab/sdg/pipeline.py b/src/instructlab/sdg/pipeline.py index d5c002ab..52f9db34 100644 --- a/src/instructlab/sdg/pipeline.py +++ b/src/instructlab/sdg/pipeline.py @@ -16,7 +16,7 @@ # First Party from instructlab.sdg.checkpointing import Checkpointer -from instructlab.sdg.utils import models, 
pandas +from instructlab.sdg.utils import pandas # Local from .blocks import llmblock @@ -61,9 +61,9 @@ class PipelineContext: # pylint: disable=too-many-instance-attributes DEFAULT_DATASET_NUM_PROCS = 8 client: OpenAI - model_family: str - model_id: str - num_instructions_to_generate: int + model_family: Optional[str] = None + model_id: Optional[str] = None + num_instructions_to_generate: Optional[int] = None dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS checkpoint_dir: Optional[str] = None save_freq: Optional[int] = 1 @@ -71,9 +71,6 @@ class PipelineContext: # pylint: disable=too-many-instance-attributes batch_size: int = DEFAULT_BATCH_SIZE batch_num_workers: Optional[int] = None - def __post_init__(self): - self.model_family = models.get_model_family(self.model_family, self.model_id) - @property def batching_enabled(self) -> bool: """Batching is enabled IFF the batch size is specified and the number of diff --git a/src/instructlab/sdg/pipelines/schema/v1.json b/src/instructlab/sdg/pipelines/schema/v1.json index 529f57c1..369ef0aa 100644 --- a/src/instructlab/sdg/pipelines/schema/v1.json +++ b/src/instructlab/sdg/pipelines/schema/v1.json @@ -94,6 +94,12 @@ "type": "string" } }, + "model_id": { + "type": "string" + }, + "model_family": { + "type": "string" + }, "model_prompt": { "type": "string" }, @@ -177,6 +183,12 @@ "type": "string" } }, + "model_id": { + "type": "string" + }, + "model_family": { + "type": "string" + }, "model_prompt": { "type": "string" }, diff --git a/tests/functional/test_granular_api.py b/tests/functional/test_granular_api.py index 6b83bbc6..4238c2be 100644 --- a/tests/functional/test_granular_api.py +++ b/tests/functional/test_granular_api.py @@ -91,6 +91,7 @@ def test_granular_api_end_to_end(self): input_dir=preprocessed_dir, output_dir=generated_dir, pipeline=pipeline_dir, + model_id="mock", num_cpus=1, # Test is faster running on a single CPU vs forking batch_size=0, # Disable batch for tiny dataset and fastest test ) diff --git a/tests/test_llmblock.py b/tests/test_llmblock.py index 6505861b..4e2ab49a 100644 --- a/tests/test_llmblock.py +++ b/tests/test_llmblock.py @@ -22,6 +22,7 @@ LLMMessagesBlock, ) from src.instructlab.sdg.blocks.llmblock import server_supports_batched +from src.instructlab.sdg.utils import models @patch("src.instructlab.sdg.blocks.block.Block._load_config") @@ -202,6 +203,162 @@ def test_validate(self, mock_load_config): assert not block._validate(block.prompt_template, {}) assert block._validate(block.prompt_template, {"var1": "foo", "var2": "bar"}) + def test_n_scaled_with_num_instructions(self, mock_load_config): + mock_load_config.return_value = { + "system": "{{fruit}}", + "introduction": "introduction", + "principles": "principles", + "examples": "examples", + "generation": "generation", + } + # save state for future tests + old_num_instructions = self.mock_ctx.num_instructions_to_generate + + self.mock_ctx.num_instructions_to_generate = None + with pytest.raises(BlockConfigParserError) as exc: + block = LLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_path="", + output_cols=[], + gen_kwargs={"n": "scaled"}, + ) + assert "num_instructions_to_generate was not set" in str(exc.value) + + self.mock_ctx.num_instructions_to_generate = 5 + block = LLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_path="", + output_cols=[], + gen_kwargs={"n": "scaled"}, + ) + assert block + + # restore state for future tests + 
self.mock_ctx.num_instructions_to_generate = old_num_instructions + + def test_resolve_model_id(self, mock_load_config): + mock_load_config.return_value = { + "system": "{{fruit}}", + "introduction": "introduction", + "principles": "principles", + "examples": "examples", + "generation": "generation", + } + # save state for future tests + old_model_id = self.mock_ctx.model_id + + # Raise an error when no model_id on block or PipelineContext + self.mock_ctx.model_id = None + with pytest.raises(BlockConfigParserError) as exc: + block = LLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_path="", + output_cols=[], + ) + assert "requires a model_id but none" in str(exc.value) + + # model_id on block is default + block = LLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_path="", + output_cols=[], + model_id="mistralai/Mixtral-8x7B-Instruct-v0.1", + ) + assert block.model_id == "mistralai/Mixtral-8x7B-Instruct-v0.1" + + # model_id on PipelineContext gets used + self.mock_ctx.model_id = "test_model" + block = LLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_path="", + output_cols=[], + ) + assert block.model_id == "test_model" + + # model_id on PipelineContext overrides + self.mock_ctx.model_id = "test_model" + block = LLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_path="", + output_cols=[], + model_id="mistralai/Mixtral-8x7B-Instruct-v0.1", + ) + assert block.model_id == "test_model" + + # restore state for future tests + self.mock_ctx.model_id = old_model_id + + def test_resolve_model_family(self, mock_load_config): + mock_load_config.return_value = { + "system": "{{fruit}}", + "introduction": "introduction", + "principles": "principles", + "examples": "examples", + "generation": "generation", + } + # save state for future tests + old_model_family = self.mock_ctx.model_family + + # Uses a default model_family when none specified + self.mock_ctx.model_family = None + block = LLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_path="", + output_cols=[], + ) + assert block.model_family == models.DEFAULT_MODEL_FAMILY + + # model_family on block is used by defalt + block = LLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_path="", + output_cols=[], + model_family="granite", + ) + assert block.model_family == "granite" + + # model_family on PipelineContext gets used + self.mock_ctx.model_family = "mixtral" + block = LLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_path="", + output_cols=[], + ) + assert block.model_family == "mixtral" + + # model_family on PipelineContext overrides + self.mock_ctx.model_family = "mixtral" + block = LLMBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + config_path="", + output_cols=[], + model_family="granite", + ) + assert block.model_family == "mixtral" + + # restore state for future tests + self.mock_ctx.model_family = old_model_family + class TestLLMBlockBatching(unittest.TestCase): def setUp(self): @@ -481,3 +638,56 @@ def test_calls_chat_completion_api(self): assert "output" in output.column_names mock_completion.create.assert_called() assert mock_completion.create.call_args.kwargs["messages"] == "my message" + + def test_resolve_model_id(self): + # save state for future tests + old_model_id = self.mock_ctx.model_id + + # Raise an 
error when no model_id on block or PipelineContext + self.mock_ctx.model_id = None + with pytest.raises(BlockConfigParserError) as exc: + block = LLMMessagesBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + input_col="messages", + output_col="output", + ) + assert "requires a model_id but none" in str(exc.value) + + # model_id on block is default + block = LLMMessagesBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + input_col="messages", + output_col="output", + model_id="mistralai/Mixtral-8x7B-Instruct-v0.1", + ) + assert block.model_id == "mistralai/Mixtral-8x7B-Instruct-v0.1" + + # model_id on PipelineContext gets used + self.mock_ctx.model_id = "test_model" + block = LLMMessagesBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + input_col="messages", + output_col="output", + ) + assert block.model_id == "test_model" + + # model_id on PipelineContext overrides + self.mock_ctx.model_id = "test_model" + block = LLMMessagesBlock( + ctx=self.mock_ctx, + pipe=self.mock_pipe, + block_name="gen_knowledge", + input_col="messages", + output_col="output", + model_id="mistralai/Mixtral-8x7B-Instruct-v0.1", + ) + assert block.model_id == "test_model" + + # restore state for future tests + self.mock_ctx.model_id = old_model_id diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index c95ed486..0ede084d 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -14,7 +14,7 @@ import pytest # First Party -from instructlab.sdg import Block, Pipeline, PipelineBlockError +from instructlab.sdg import Block, Pipeline, PipelineBlockError, PipelineContext ## Helpers ## @@ -259,3 +259,90 @@ def test_block_generation_error_properties_from_strings(): str(gen_err) == f"{PipelineBlockError.__name__}({block_type}/{block_name}): {inner_err}" ) + + +def test_pipeline_context_backwards_compat(): + """ + Test backwards compatibility of PipelineContext signatures + where client, model_family, model_id, num_instructions_to_generate + used to be required before SDG 0.8.x but since then are not. 
+ """ + client = mock.MagicMock() + model_family = "granite" + model_id = "bar" + num_instructions_to_generate = 200 + save_freq = 5 + + # <= SDG 0.7.x - only passing old required positional params + ctx = PipelineContext(client, model_family, model_id, num_instructions_to_generate) + assert ctx.client == client + assert ctx.model_family == model_family + assert ctx.model_id == model_id + assert ctx.num_instructions_to_generate == num_instructions_to_generate + + # <= SDG 0.7.x - old required positional params with kwargs + ctx = PipelineContext( + client, + model_family, + model_id, + num_instructions_to_generate, + save_freq=save_freq, + ) + assert ctx.client == client + assert ctx.model_family == model_family + assert ctx.model_id == model_id + assert ctx.num_instructions_to_generate == num_instructions_to_generate + assert ctx.save_freq == save_freq + + # <= SDG 0.7.x - passing all params as kwargs + ctx = PipelineContext( + client=client, + model_family=model_family, + model_id=model_id, + num_instructions_to_generate=num_instructions_to_generate, + save_freq=save_freq, + ) + assert ctx.client == client + assert ctx.model_family == model_family + assert ctx.model_id == model_id + assert ctx.num_instructions_to_generate == num_instructions_to_generate + assert ctx.save_freq == save_freq + + # SDG 0.8.x - num_instructions_to_generate no longer required + ctx = PipelineContext(client, model_family, model_id) + assert ctx.client == client + assert ctx.model_family == model_family + assert ctx.model_id == model_id + assert ctx.num_instructions_to_generate is None + + # SDG 0.8.x - num_instructions_to_generate no longer required with kwargs + ctx = PipelineContext(client, model_family, model_id, save_freq=save_freq) + assert ctx.client == client + assert ctx.model_family == model_family + assert ctx.model_id == model_id + assert ctx.num_instructions_to_generate is None + assert ctx.save_freq == save_freq + + # SDG 0.8.x - model_id no longer required + ctx = PipelineContext(client, model_family) + assert ctx.client == client + assert ctx.model_family == model_family + assert ctx.model_id is None + + # SDG 0.8.x - model_id no longer required with kwargs + ctx = PipelineContext(client, model_family, save_freq=save_freq) + assert ctx.client == client + assert ctx.model_family == model_family + assert ctx.model_id is None + assert ctx.save_freq == save_freq + + # SDG 0.8.x - model_family no longer required + ctx = PipelineContext(client) + assert ctx.client == client + assert ctx.model_family is None + + # SDG 0.8.x - model_family no longer required with kwargs + ctx = PipelineContext(client, save_freq=save_freq) + assert ctx.client == client + assert ctx.model_family is None + assert ctx.save_freq == save_freq diff --git a/tests/testdata/custom_block.py b/tests/testdata/custom_block.py index 409d4ebc..2600d341 100644 --- a/tests/testdata/custom_block.py +++ b/tests/testdata/custom_block.py @@ -16,7 +16,7 @@ def generate(self, samples: Dataset): return samples -pipeline_context = PipelineContext(None, "mixtral", "my_model", 5) +pipeline_context = PipelineContext(None) pipeline_yaml = pathlib.Path(__file__).parent.joinpath("custom_block_pipeline.yaml") pipeline = Pipeline.from_file(pipeline_context, pipeline_yaml) input_ds = Dataset.from_list(
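A short sketch of the override behavior described in the commit message, assuming an OpenAI-compatible server at a placeholder URL: values set on the `PipelineContext` apply to every `LLMBlock` in the pipeline, so they are best left unset when a pipeline mixes models.

```python
# Illustrative only: the client settings are placeholders.
from openai import OpenAI

from instructlab.sdg import Pipeline, PipelineContext

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Setting model_id/model_family here overrides the yaml config of every LLMBlock ...
override_ctx = PipelineContext(
    client=client, model_family="granite", model_id="granite-7b-lab"
)
overridden = Pipeline.from_file(override_ctx, "docs/examples/multiple_llms/pipeline.yaml")

# ... so for a pipeline that mixes models, set only the client and let each block decide.
per_block_ctx = PipelineContext(client=client)
multi_llm = Pipeline.from_file(per_block_ctx, "docs/examples/multiple_llms/pipeline.yaml")
```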