diff --git a/CHANGELOG.md b/CHANGELOG.md
index c513bc63..1518ed52 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,15 @@
+## Upcoming v0.8.x
+
+### Features
+
+#### LLMBlocks can now specify `model_family` or `model_id` in their config
+
+Each `LLMBlock` in a `Pipeline` can now specify `model_family` or `model_id` in its yaml configuration, instead of setting these values for the entire `Pipeline` in the `PipelineContext` object. This is useful when multiple `LLMBlock` entries in the same `Pipeline` each use a different model.
+
+#### Fewer required parameters for `PipelineContext`
+
+The parameters `model_family`, `model_id`, and `num_instructions_to_generate` are no longer required in `PipelineContext` objects. If passed in, they are still used as before, but they can now be omitted if your `Pipeline` contains no `LLMBlock` entries or if your `LLMBlock` configs specify these values in the `Pipeline` yaml.
+
 ## v0.7.0
 
 ### Features
diff --git a/docs/examples/multiple_llms/README.md b/docs/examples/multiple_llms/README.md
new file mode 100644
index 00000000..19a55cd7
--- /dev/null
+++ b/docs/examples/multiple_llms/README.md
@@ -0,0 +1,5 @@
+# Multiple LLMs in a single Pipeline
+
+To create a `Pipeline` that uses multiple different LLMs for inference, each `LLMBlock` in the `Pipeline` yaml should specify a `model_family` (indicating which set of prompts to use) and a `model_id` (the id of this model in the deployed inference server). For more control over the prompting, a `model_prompt` may be given in place of `model_family` to specify the exact prompt to use, rather than a generic one chosen for that model family.
+
+See `pipeline.yaml` in this directory for an example of a Pipeline that calls into three separate LLMs.
diff --git a/docs/examples/multiple_llms/pipeline.yaml b/docs/examples/multiple_llms/pipeline.yaml
new file mode 100644
index 00000000..1d2af75c
--- /dev/null
+++ b/docs/examples/multiple_llms/pipeline.yaml
@@ -0,0 +1,34 @@
+version: "1.0"
+blocks:
+  - name: model_one
+    type: LLMBlock
+    config:
+      model_family: mixtral
+      model_id: Mixtral-8x7B-Instruct-v0.1
+      config_path: model_one_config.yaml
+      output_cols:
+        - column_one
+      gen_kwargs:
+        max_tokens: 2048
+
+  - name: model_two
+    type: LLMBlock
+    config:
+      model_family: granite
+      model_id: granite-7b-lab
+      config_path: model_two_config.yaml
+      output_cols:
+        - column_two
+      gen_kwargs:
+        max_tokens: 512
+
+  - name: model_three
+    type: LLMBlock
+    config:
+      model_id: granite-7b-lab
+      model_prompt: "[INST] {prompt} [/INST]"
+      config_path: model_three_config.yaml
+      output_cols:
+        - column_three
+      gen_kwargs:
+        max_tokens: 5
diff --git a/scripts/validate_pipelines.py b/scripts/validate_pipelines.py
index 4cb6467f..fcfcf41c 100755
--- a/scripts/validate_pipelines.py
+++ b/scripts/validate_pipelines.py
@@ -30,6 +30,7 @@ def main():
         schema = json.load(file)
 
     yaml_files = glob.glob("src/instructlab/sdg/pipelines/**/*.yaml", recursive=True)
+    yaml_files.extend(glob.glob("docs/examples/**/pipeline.yaml", recursive=True))
     all_valid = True
    for yaml_file in yaml_files:
         print("=======================================================")
diff --git a/src/instructlab/sdg/blocks/llmblock.py b/src/instructlab/sdg/blocks/llmblock.py
index 3350a771..de3b2fdf 100644
--- a/src/instructlab/sdg/blocks/llmblock.py
+++ b/src/instructlab/sdg/blocks/llmblock.py
@@ -15,6 +15,7 @@
 # Import prompts to register default chat templates
 from .. import prompts as default_prompts  # pylint: disable=unused-import
 from ..registry import BlockRegistry, PromptRegistry
+from ..utils import models
 from .block import Block, BlockConfigParserError
 
 logger = logging.getLogger(__name__)
@@ -62,6 +63,33 @@ def template_from_struct_and_config(struct, config):
     return PromptRegistry.template_from_string(struct.format(**filtered_config))
 
 
+def _resolve_model_id(model_id, ctx_model_id, block):
+    # If a model id was passed in the PipelineContext, use that
+    if ctx_model_id:
+        return ctx_model_id
+
+    # If we have no model id at all, raise an error
+    if not model_id:
+        raise BlockConfigParserError(
+            f"{type(block).__name__} {block.block_name} requires a model_id but none was specified in the block config nor passed via the PipelineContext"
+        )
+
+    # Otherwise fall back to the model_id specified in the block config
+    return model_id
+
+
+def _resolve_model_family(model_family, ctx_model_family):
+    # If a model family was passed in the PipelineContext, use that
+    if ctx_model_family:
+        return ctx_model_family
+
+    # Otherwise fall back to the model_family specified in the block config
+    # This could be None, but that's ok for model_family because when the
+    # actual prompt gets selected it falls back to merlinite if nothing
+    # else was given.
+    return model_family
+
+
 # This is part of the public API.
 @BlockRegistry.register("LLMBlock")
 # pylint: disable=dangerous-default-value
@@ -74,6 +102,8 @@ def __init__(
         block_name,
         config_path,
         output_cols,
+        model_id=None,
+        model_family=None,
         model_prompt=None,
         gen_kwargs={},
         parser_kwargs={},
@@ -87,6 +117,11 @@
         self.prompt_template = template_from_struct_and_config(
             self.prompt_struct, self.block_config
         )
+        self.model_id = _resolve_model_id(model_id, self.ctx.model_id, self)
+        self.model_family = models.get_model_family(
+            _resolve_model_family(model_family, self.ctx.model_family),
+            self.model_id,
+        )
         self.model_prompt = model_prompt
         self.output_cols = output_cols
         self.batch_params = batch_kwargs
@@ -104,14 +139,14 @@ def __init__(
         self.gen_kwargs = self._gen_kwargs(
             max_num_token_override,
             gen_kwargs,
-            model=self.ctx.model_id,
+            model=self.model_id,
             temperature=0,
             max_tokens=DEFAULT_MAX_NUM_TOKENS,
         )
         # Whether the LLM server supports a list of input prompts
         # and supports the n parameter to generate n outputs per input
         self.server_supports_batched = server_supports_batched(
-            self.ctx.client, self.ctx.model_id
+            self.ctx.client, self.model_id
         )
 
     def _parse(self, generated_string) -> dict:
@@ -160,7 +195,7 @@ def _format_prompt(self, sample: Dict) -> str:
 
         model_prompt = None
         if self.model_prompt is None:
-            model_prompt = PromptRegistry.get_template(self.ctx.model_family)
+            model_prompt = PromptRegistry.get_template(self.model_family)
         elif self.model_prompt:
             model_prompt = PromptRegistry.template_from_string(self.model_prompt)
         else:
@@ -183,6 +218,10 @@ def _gen_kwargs(self, max_num_token_override, gen_kwargs, **defaults):
             and isinstance(gen_kwargs["n"], str)
             and gen_kwargs["n"] == "scaled"
         ):
+            if not self.ctx.num_instructions_to_generate:
+                raise BlockConfigParserError(
+                    f"""LLMBlock {self.block_name} has a gen_kwargs["n"] value of "scaled" but num_instructions_to_generate was not set in the PipelineContext"""
+                )
             gen_kwargs["n"] = self.ctx.num_instructions_to_generate
         if "temperature" in gen_kwargs:
             gen_kwargs["temperature"] = float(gen_kwargs["temperature"])
@@ -285,6 +324,8 @@ def __init__(
         config_paths,
         output_cols,
         selector_column_name,
+        model_id=None,
+        model_family=None,
         model_prompt=None,
         gen_kwargs={},
         parser_kwargs={},
@@ -305,6 +346,8 @@ def __init__(
             block_name,
             config_paths[0][0],
             output_cols,
+            model_id=model_id,
+            model_family=model_family,
             model_prompt=model_prompt,
             gen_kwargs=gen_kwargs,
             parser_kwargs=parser_kwargs,
@@ -360,6 +403,8 @@ def __init__(
         block_name,
         config_path,
         output_cols,
+        model_id=None,
+        model_family=None,
         model_prompt=None,
         gen_kwargs={},
         parser_kwargs={},
@@ -371,6 +416,8 @@ def __init__(
             block_name,
             config_path,
             output_cols,
+            model_id=model_id,
+            model_family=model_family,
             model_prompt=model_prompt,
             gen_kwargs=gen_kwargs,
             parser_kwargs=parser_kwargs,
@@ -466,14 +513,16 @@ def __init__(
         block_name,
         input_col,
         output_col,
+        model_id=None,
         gen_kwargs={},
     ) -> None:
         super().__init__(ctx, pipe, block_name)
+        self.model_id = _resolve_model_id(model_id, self.ctx.model_id, self)
         self.input_col = input_col
         self.output_col = output_col
         self.gen_kwargs = self._gen_kwargs(
             gen_kwargs,
-            model=self.ctx.model_id,
+            model=self.model_id,
             temperature=0,
             max_tokens=DEFAULT_MAX_NUM_TOKENS,
         )
diff --git a/src/instructlab/sdg/pipeline.py b/src/instructlab/sdg/pipeline.py
index d5c002ab..52f9db34 100644
--- a/src/instructlab/sdg/pipeline.py
+++ b/src/instructlab/sdg/pipeline.py
@@ -16,7 +16,7 @@
 
 # First Party
 from instructlab.sdg.checkpointing import Checkpointer
-from instructlab.sdg.utils import models, pandas
+from instructlab.sdg.utils import pandas
 
 # Local
 from .blocks import llmblock
@@ -61,9 +61,9 @@ class PipelineContext:  # pylint: disable=too-many-instance-attributes
     DEFAULT_DATASET_NUM_PROCS = 8
 
     client: OpenAI
-    model_family: str
-    model_id: str
-    num_instructions_to_generate: int
+    model_family: Optional[str] = None
+    model_id: Optional[str] = None
+    num_instructions_to_generate: Optional[int] = None
     dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS
     checkpoint_dir: Optional[str] = None
     save_freq: Optional[int] = 1
@@ -71,9 +71,6 @@ class PipelineContext:  # pylint: disable=too-many-instance-attributes
     batch_size: int = DEFAULT_BATCH_SIZE
     batch_num_workers: Optional[int] = None
 
-    def __post_init__(self):
-        self.model_family = models.get_model_family(self.model_family, self.model_id)
-
     @property
     def batching_enabled(self) -> bool:
         """Batching is enabled IFF the batch size is specified and the number of
diff --git a/src/instructlab/sdg/pipelines/schema/v1.json b/src/instructlab/sdg/pipelines/schema/v1.json
index 690d3d73..cfb778f5 100644
--- a/src/instructlab/sdg/pipelines/schema/v1.json
+++ b/src/instructlab/sdg/pipelines/schema/v1.json
@@ -94,6 +94,12 @@
                 "type": "string"
               }
             },
+            "model_id": {
+              "type": "string"
+            },
+            "model_family": {
+              "type": "string"
+            },
             "model_prompt": {
               "type": "string"
             },
@@ -177,6 +183,12 @@
                 "type": "string"
               }
             },
+            "model_id": {
+              "type": "string"
+            },
+            "model_family": {
+              "type": "string"
+            },
             "model_prompt": {
               "type": "string"
             },
diff --git a/tests/functional/test_granular_api.py b/tests/functional/test_granular_api.py
index 6b83bbc6..4238c2be 100644
--- a/tests/functional/test_granular_api.py
+++ b/tests/functional/test_granular_api.py
@@ -91,6 +91,7 @@ def test_granular_api_end_to_end(self):
             input_dir=preprocessed_dir,
             output_dir=generated_dir,
             pipeline=pipeline_dir,
+            model_id="mock",
             num_cpus=1,  # Test is faster running on a single CPU vs forking
             batch_size=0,  # Disable batch for tiny dataset and fastest test
         )
diff --git a/tests/test_llmblock.py b/tests/test_llmblock.py
index 6505861b..4e2ab49a 100644
--- a/tests/test_llmblock.py
+++ b/tests/test_llmblock.py
@@ -22,6 +22,7 @@
     LLMMessagesBlock,
 )
 from src.instructlab.sdg.blocks.llmblock import server_supports_batched
+from src.instructlab.sdg.utils import models
 
 
 @patch("src.instructlab.sdg.blocks.block.Block._load_config")
@@ -202,6 +203,162 @@ def test_validate(self, mock_load_config):
         assert not block._validate(block.prompt_template, {})
         assert block._validate(block.prompt_template, {"var1": "foo", "var2": "bar"})
 
+    def test_n_scaled_with_num_instructions(self, mock_load_config):
+        mock_load_config.return_value = {
+            "system": "{{fruit}}",
+            "introduction": "introduction",
+            "principles": "principles",
+            "examples": "examples",
+            "generation": "generation",
+        }
+        # save state for future tests
+        old_num_instructions = self.mock_ctx.num_instructions_to_generate
+
+        self.mock_ctx.num_instructions_to_generate = None
+        with pytest.raises(BlockConfigParserError) as exc:
+            block = LLMBlock(
+                ctx=self.mock_ctx,
+                pipe=self.mock_pipe,
+                block_name="gen_knowledge",
+                config_path="",
+                output_cols=[],
+                gen_kwargs={"n": "scaled"},
+            )
+        assert "num_instructions_to_generate was not set" in str(exc.value)
+
+        self.mock_ctx.num_instructions_to_generate = 5
+        block = LLMBlock(
+            ctx=self.mock_ctx,
+            pipe=self.mock_pipe,
+            block_name="gen_knowledge",
+            config_path="",
+            output_cols=[],
+            gen_kwargs={"n": "scaled"},
+        )
+        assert block
+
+        # restore state for future tests
+        self.mock_ctx.num_instructions_to_generate = old_num_instructions
+
+    def test_resolve_model_id(self, mock_load_config):
+        mock_load_config.return_value = {
+            "system": "{{fruit}}",
+            "introduction": "introduction",
+            "principles": "principles",
+            "examples": "examples",
+            "generation": "generation",
+        }
+        # save state for future tests
+        old_model_id = self.mock_ctx.model_id
+
+        # Raise an error when no model_id on block or PipelineContext
+        self.mock_ctx.model_id = None
+        with pytest.raises(BlockConfigParserError) as exc:
+            block = LLMBlock(
+                ctx=self.mock_ctx,
+                pipe=self.mock_pipe,
+                block_name="gen_knowledge",
+                config_path="",
+                output_cols=[],
+            )
+        assert "requires a model_id but none" in str(exc.value)
+
+        # model_id on block is default
+        block = LLMBlock(
+            ctx=self.mock_ctx,
+            pipe=self.mock_pipe,
+            block_name="gen_knowledge",
+            config_path="",
+            output_cols=[],
+            model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
+        )
+        assert block.model_id == "mistralai/Mixtral-8x7B-Instruct-v0.1"
+
+        # model_id on PipelineContext gets used
+        self.mock_ctx.model_id = "test_model"
+        block = LLMBlock(
+            ctx=self.mock_ctx,
+            pipe=self.mock_pipe,
+            block_name="gen_knowledge",
+            config_path="",
+            output_cols=[],
+        )
+        assert block.model_id == "test_model"
+
+        # model_id on PipelineContext overrides
+        self.mock_ctx.model_id = "test_model"
+        block = LLMBlock(
+            ctx=self.mock_ctx,
+            pipe=self.mock_pipe,
+            block_name="gen_knowledge",
+            config_path="",
+            output_cols=[],
+            model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
+        )
+        assert block.model_id == "test_model"
+
+        # restore state for future tests
+        self.mock_ctx.model_id = old_model_id
+
+    def test_resolve_model_family(self, mock_load_config):
+        mock_load_config.return_value = {
+            "system": "{{fruit}}",
+            "introduction": "introduction",
+            "principles": "principles",
+            "examples": "examples",
+            "generation": "generation",
+        }
+        # save state for future tests
+        old_model_family = self.mock_ctx.model_family
+
+        # Uses a default model_family when none specified
+        self.mock_ctx.model_family = None
+        block = LLMBlock(
+            ctx=self.mock_ctx,
+            pipe=self.mock_pipe,
+            block_name="gen_knowledge",
+            config_path="",
+            output_cols=[],
+        )
+        assert block.model_family == models.DEFAULT_MODEL_FAMILY
+
+        # model_family on block is used by default
+        block = LLMBlock(
+            ctx=self.mock_ctx,
+            pipe=self.mock_pipe,
+            block_name="gen_knowledge",
+            config_path="",
+            output_cols=[],
+            model_family="granite",
+        )
+        assert block.model_family == "granite"
+
+        # model_family on PipelineContext gets used
+        self.mock_ctx.model_family = "mixtral"
+        block = LLMBlock(
+            ctx=self.mock_ctx,
+            pipe=self.mock_pipe,
+            block_name="gen_knowledge",
+            config_path="",
+            output_cols=[],
+        )
+        assert block.model_family == "mixtral"
+
+        # model_family on PipelineContext overrides
+        self.mock_ctx.model_family = "mixtral"
+        block = LLMBlock(
+            ctx=self.mock_ctx,
+            pipe=self.mock_pipe,
+            block_name="gen_knowledge",
+            config_path="",
+            output_cols=[],
+            model_family="granite",
+        )
+        assert block.model_family == "mixtral"
+
+        # restore state for future tests
+        self.mock_ctx.model_family = old_model_family
+
 
 class TestLLMBlockBatching(unittest.TestCase):
     def setUp(self):
@@ -481,3 +638,56 @@ def test_calls_chat_completion_api(self):
         assert "output" in output.column_names
         mock_completion.create.assert_called()
         assert mock_completion.create.call_args.kwargs["messages"] == "my message"
+
+    def test_resolve_model_id(self):
+        # save state for future tests
+        old_model_id = self.mock_ctx.model_id
+
+        # Raise an error when no model_id on block or PipelineContext
+        self.mock_ctx.model_id = None
+        with pytest.raises(BlockConfigParserError) as exc:
+            block = LLMMessagesBlock(
+                ctx=self.mock_ctx,
+                pipe=self.mock_pipe,
+                block_name="gen_knowledge",
+                input_col="messages",
+                output_col="output",
+            )
+        assert "requires a model_id but none" in str(exc.value)
+
+        # model_id on block is default
+        block = LLMMessagesBlock(
+            ctx=self.mock_ctx,
+            pipe=self.mock_pipe,
+            block_name="gen_knowledge",
+            input_col="messages",
+            output_col="output",
+            model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
+        )
+        assert block.model_id == "mistralai/Mixtral-8x7B-Instruct-v0.1"
+
+        # model_id on PipelineContext gets used
+        self.mock_ctx.model_id = "test_model"
+        block = LLMMessagesBlock(
+            ctx=self.mock_ctx,
+            pipe=self.mock_pipe,
+            block_name="gen_knowledge",
+            input_col="messages",
+            output_col="output",
+        )
+        assert block.model_id == "test_model"
+
+        # model_id on PipelineContext overrides
+        self.mock_ctx.model_id = "test_model"
+        block = LLMMessagesBlock(
+            ctx=self.mock_ctx,
+            pipe=self.mock_pipe,
+            block_name="gen_knowledge",
+            input_col="messages",
+            output_col="output",
+            model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
+        )
+        assert block.model_id == "test_model"
+
+        # restore state for future tests
+        self.mock_ctx.model_id = old_model_id
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index c95ed486..0ede084d 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -14,7 +14,7 @@
 import pytest
 
 # First Party
-from instructlab.sdg import Block, Pipeline, PipelineBlockError
+from instructlab.sdg import Block, Pipeline, PipelineBlockError, PipelineContext
 
 
 ## Helpers ##
@@ -259,3 +259,90 @@ def test_block_generation_error_properties_from_strings():
         str(gen_err)
         == f"{PipelineBlockError.__name__}({block_type}/{block_name}): {inner_err}"
     )
+
+
+def test_pipeline_context_backwards_compat():
+    """
+    Test backwards compatibility of PipelineContext signatures
+    where client, model_family, model_id, num_instructions_to_generate
+    used to be required before SDG 0.8.x but since then are not.
+    """
+    client = mock.MagicMock()
+    model_family = "granite"
+    model_id = "bar"
+    num_instructions_to_generate = 200
+    save_freq = 5
+
+    # <= SDG 0.7.x - only passing old required positional params
+    ctx = PipelineContext(client, model_family, model_id, num_instructions_to_generate)
+    assert ctx.client == client
+    assert ctx.model_family == model_family
+    assert ctx.model_id == model_id
+    assert ctx.num_instructions_to_generate == num_instructions_to_generate
+
+    # <= SDG 0.7.x - old required positional params with kwargs
+    ctx = PipelineContext(
+        client,
+        model_family,
+        model_id,
+        num_instructions_to_generate,
+        save_freq=save_freq,
+    )
+    assert ctx.client == client
+    assert ctx.model_family == model_family
+    assert ctx.model_id == model_id
+    assert ctx.num_instructions_to_generate == num_instructions_to_generate
+    assert ctx.save_freq == save_freq
+
+    # <= SDG 0.7.x - passing all params as kwargs
+    ctx = PipelineContext(
+        client=client,
+        model_family=model_family,
+        model_id=model_id,
+        num_instructions_to_generate=num_instructions_to_generate,
+        save_freq=save_freq,
+    )
+    assert ctx.client == client
+    assert ctx.model_family == model_family
+    assert ctx.model_id == model_id
+    assert ctx.num_instructions_to_generate == num_instructions_to_generate
+    assert ctx.save_freq == save_freq
+
+    # SDG 0.8.x - num_instructions_to_generate no longer required
+    ctx = PipelineContext(client, model_family, model_id)
+    assert ctx.client == client
+    assert ctx.model_family == model_family
+    assert ctx.model_id == model_id
+    assert ctx.num_instructions_to_generate is None
+
+    # SDG 0.8.x - num_instructions_to_generate no longer required with kwargs
+    ctx = PipelineContext(client, model_family, model_id, save_freq=save_freq)
+    assert ctx.client == client
+    assert ctx.model_family == model_family
+    assert ctx.model_id == model_id
+    assert ctx.num_instructions_to_generate is None
+    assert ctx.save_freq == save_freq
+
+    # SDG 0.8.x - model_id no longer required
+    ctx = PipelineContext(client, model_family)
+    assert ctx.client == client
+    assert ctx.model_family == model_family
+    assert ctx.model_id is None
+
+    # SDG 0.8.x - model_id no longer required with kwargs
+    ctx = PipelineContext(client, model_family, save_freq=save_freq)
+    assert ctx.client == client
+    assert ctx.model_family == model_family
+    assert ctx.model_id is None
+    assert ctx.save_freq == save_freq
+
+    # SDG 0.8.x - model_family no longer required
+    ctx = PipelineContext(client)
+    assert ctx.client == client
+    assert ctx.model_family is None
+
+    # SDG 0.8.x - model_family no longer required with kwargs
+    ctx = PipelineContext(client, save_freq=save_freq)
+    assert ctx.client == client
+    assert ctx.model_family is None
+    assert ctx.save_freq == save_freq
diff --git a/tests/testdata/custom_block.py b/tests/testdata/custom_block.py
index 409d4ebc..2600d341 100644
--- a/tests/testdata/custom_block.py
+++ b/tests/testdata/custom_block.py
@@ -16,7 +16,7 @@ def generate(self, samples: Dataset):
         return samples
 
 
-pipeline_context = PipelineContext(None, "mixtral", "my_model", 5)
+pipeline_context = PipelineContext(None)
 pipeline_yaml = pathlib.Path(__file__).parent.joinpath("custom_block_pipeline.yaml")
 pipeline = Pipeline.from_file(pipeline_context, pipeline_yaml)
 input_ds = Dataset.from_list(