Make everything but client optional on PipelineContext
This makes model_family, model_id, and num_instructions_to_generate
optional on PipelineContext instead of required. This should be a
backwards-compatible change, and it includes tests verifying that
passing the previously required parameters into PipelineContext still
works as expected. This is a step towards supporting Pipelines whose
LLMBlocks use different model families or different model ids.

If a Block's yaml config specifies a model_family or model_id and none
is set on the PipelineContext, the Block's values are used. However, if
these are set on the PipelineContext, they are treated as overrides and
take precedence for every Block in this Pipeline. This matches the
previous behavior, but users should avoid setting model_id or
model_family on a PipelineContext when the Pipeline contains multiple
LLMBlock entries.
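
A rough sketch of the two modes, assuming an OpenAI-compatible client
pointed at a placeholder endpoint (the model names below are just
examples):

```python
from openai import OpenAI
from instructlab.sdg.pipeline import PipelineContext

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # placeholder endpoint

# Per-block mode: nothing set on the context, so each LLMBlock uses the
# model_family / model_id from its own yaml config.
ctx = PipelineContext(client=client)

# Override mode: values set on the context take precedence for every
# LLMBlock in the Pipeline, matching the previous single-model behavior.
ctx_override = PipelineContext(
    client=client,
    model_family="granite",
    model_id="granite-7b-lab",
)
```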

Fixes #511

Signed-off-by: Ben Browning <[email protected]>
bbrowning committed Jan 29, 2025
1 parent d4a7136 commit c8145f8
Showing 11 changed files with 421 additions and 13 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,15 @@
## Upcoming v0.8.x

### Features

#### LLMBlocks can now specify `model_family` or `model_id` in their config

Each `LLMBlock` in a `Pipeline` can now specify `model_family` or `model_id` in its yaml configuration to set the values used for that block, instead of setting them for the entire `Pipeline` in the `PipelineContext` object. This is useful when multiple `LLMBlock` entries in the same `Pipeline` each use a different model.

#### Fewer required parameters for `PipelineContext`

The parameters `model_family`, `model_id`, and `num_instructions_to_generate` are no longer required on `PipelineContext` objects. If passed in, they are still used as before, but they can now be omitted when your `Pipeline` contains no `LLMBlock` entries or when your `LLMBlock` configs specify these values in the `Pipeline` yaml.
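
For illustration, a minimal context now only needs a client (the endpoint below is a placeholder); `num_instructions_to_generate` is still required if any `LLMBlock` sets its `gen_kwargs` `n` to `"scaled"`:

```python
from openai import OpenAI
from instructlab.sdg.pipeline import PipelineContext

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # placeholder endpoint

# client is now the only required argument.
ctx = PipelineContext(client=client)

# Still needed when an LLMBlock uses gen_kwargs n="scaled"; omitting it in
# that case raises a BlockConfigParserError.
ctx_scaled = PipelineContext(client=client, num_instructions_to_generate=30)
```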

## v0.7.0

### Features
5 changes: 5 additions & 0 deletions docs/examples/multiple_llms/README.md
@@ -0,0 +1,5 @@
# Multiple LLMs in a single Pipeline

To create a `Pipeline` that uses multiple LLMs for inference, each `LLMBlock` in the `Pipeline` yaml should specify a `model_family` (indicating which set of prompts to use) and a `model_id` (the id of that model on the deployed inference server). For more control over prompting, a `model_prompt` may be used in place of `model_family` to specify the exact prompt to use rather than a generic one for that model family.

See `pipeline.yaml` in this directory for an example of a Pipeline that calls into 3 separate LLMs.
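
A hedged sketch of running that example end to end (the endpoint, served model names, and input column are placeholders, and `Pipeline.from_file`/`generate` are assumed to be the usual entry points):

```python
import os

from datasets import Dataset
from openai import OpenAI
from instructlab.sdg.pipeline import Pipeline, PipelineContext

# Placeholder endpoint that serves all of the models named in pipeline.yaml.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# No model_id or model_family here: each LLMBlock resolves them from its
# own config in pipeline.yaml.
ctx = PipelineContext(client=client)

pipeline_path = os.path.abspath("docs/examples/multiple_llms/pipeline.yaml")
pipeline = Pipeline.from_file(ctx, pipeline_path)

samples = Dataset.from_list([{"document": "example input"}])  # hypothetical input column
result = pipeline.generate(samples)
print(result)
```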
34 changes: 34 additions & 0 deletions docs/examples/multiple_llms/pipeline.yaml
@@ -0,0 +1,34 @@
version: "1.0"
blocks:
- name: model_one
type: LLMBlock
config:
model_family: mixtral
model_id: Mixtral-8x7B-Instruct-v0.1
config_path: model_one_config.yaml
output_cols:
- column_one
gen_kwargs:
max_tokens: 2048

- name: model_two
type: LLMBlock
config:
model_family: granite
model_id: granite-7b-lab
config_path: model_two_config.yaml
output_cols:
- column_two
gen_kwargs:
max_tokens: 512

- name: model_three
type: LLMBlock
config:
model_id: granite-7b-lab
model_prompt: <s> [INST] {prompt} [/INST]
config_path: model_three_config.yaml
output_cols:
- column_three
gen_kwargs:
max_tokens: 5
1 change: 1 addition & 0 deletions scripts/validate_pipelines.py
@@ -30,6 +30,7 @@ def main():
schema = json.load(file)

yaml_files = glob.glob("src/instructlab/sdg/pipelines/**/*.yaml", recursive=True)
yaml_files.extend(glob.glob("docs/examples/**/pipeline.yaml", recursive=True))
all_valid = True
for yaml_file in yaml_files:
print("=======================================================")
57 changes: 53 additions & 4 deletions src/instructlab/sdg/blocks/llmblock.py
@@ -15,6 +15,7 @@
# Import prompts to register default chat templates
from .. import prompts as default_prompts # pylint: disable=unused-import
from ..registry import BlockRegistry, PromptRegistry
from ..utils import models
from .block import Block, BlockConfigParserError

logger = logging.getLogger(__name__)
@@ -62,6 +63,33 @@ def template_from_struct_and_config(struct, config):
return PromptRegistry.template_from_string(struct.format(**filtered_config))


def _resolve_model_id(model_id, ctx_model_id, block):
# If a model id was passed in the PipelineContext, use that
if ctx_model_id:
return ctx_model_id

# If we have no model id at all, raise an error
if not model_id:
raise BlockConfigParserError(
f"{type(block).__name__} {block.block_name} requires a model_id but none was specified in the block config nor passed via the PipelineContext"
)

# Otherwise fallback to the model_id specified in the block config
return model_id


def _resolve_model_family(model_family, ctx_model_family):
# If a model family was passed in the PipelineContext, use that
if ctx_model_family:
return ctx_model_family

# Otherwise fallback to the model_family specified in the block config
# This could be None, but that's ok for model_family because when the
# actual prompt gets selected it falls back to merlinite if nothing
# else was given.
return model_family


# This is part of the public API.
@BlockRegistry.register("LLMBlock")
# pylint: disable=dangerous-default-value
@@ -74,6 +102,8 @@ def __init__(
block_name,
config_path,
output_cols,
model_id=None,
model_family=None,
model_prompt=None,
gen_kwargs={},
parser_kwargs={},
@@ -87,6 +117,11 @@ def __init__(
self.prompt_template = template_from_struct_and_config(
self.prompt_struct, self.block_config
)
self.model_id = _resolve_model_id(model_id, self.ctx.model_id, self)
self.model_family = models.get_model_family(
_resolve_model_family(model_family, self.ctx.model_family),
self.model_id,
)
self.model_prompt = model_prompt
self.output_cols = output_cols
self.batch_params = batch_kwargs
@@ -104,14 +139,14 @@ def __init__(
self.gen_kwargs = self._gen_kwargs(
max_num_token_override,
gen_kwargs,
model=self.ctx.model_id,
model=self.model_id,
temperature=0,
max_tokens=DEFAULT_MAX_NUM_TOKENS,
)
# Whether the LLM server supports a list of input prompts
# and supports the n parameter to generate n outputs per input
self.server_supports_batched = server_supports_batched(
self.ctx.client, self.ctx.model_id
self.ctx.client, self.model_id
)

def _parse(self, generated_string) -> dict:
Expand Down Expand Up @@ -160,7 +195,7 @@ def _format_prompt(self, sample: Dict) -> str:

model_prompt = None
if self.model_prompt is None:
model_prompt = PromptRegistry.get_template(self.ctx.model_family)
model_prompt = PromptRegistry.get_template(self.model_family)
elif self.model_prompt:
model_prompt = PromptRegistry.template_from_string(self.model_prompt)
else:
@@ -183,6 +218,10 @@ def _gen_kwargs(self, max_num_token_override, gen_kwargs, **defaults):
and isinstance(gen_kwargs["n"], str)
and gen_kwargs["n"] == "scaled"
):
if not self.ctx.num_instructions_to_generate:
raise BlockConfigParserError(
f"""LLMBlock {self.block_name} has a gen_kwargs["n"] value of "scaled" but num_instructions_to_generate was not set in the PipelineContext"""
)
gen_kwargs["n"] = self.ctx.num_instructions_to_generate
if "temperature" in gen_kwargs:
gen_kwargs["temperature"] = float(gen_kwargs["temperature"])
@@ -285,6 +324,8 @@ def __init__(
config_paths,
output_cols,
selector_column_name,
model_id=None,
model_family=None,
model_prompt=None,
gen_kwargs={},
parser_kwargs={},
@@ -305,6 +346,8 @@ def __init__(
block_name,
config_paths[0][0],
output_cols,
model_id=model_id,
model_family=model_family,
model_prompt=model_prompt,
gen_kwargs=gen_kwargs,
parser_kwargs=parser_kwargs,
@@ -360,6 +403,8 @@ def __init__(
block_name,
config_path,
output_cols,
model_id=None,
model_family=None,
model_prompt=None,
gen_kwargs={},
parser_kwargs={},
@@ -371,6 +416,8 @@ def __init__(
block_name,
config_path,
output_cols,
model_id=model_id,
model_family=model_family,
model_prompt=model_prompt,
gen_kwargs=gen_kwargs,
parser_kwargs=parser_kwargs,
@@ -466,14 +513,16 @@ def __init__(
block_name,
input_col,
output_col,
model_id=None,
gen_kwargs={},
) -> None:
super().__init__(ctx, pipe, block_name)
self.model_id = _resolve_model_id(model_id, self.ctx.model_id, self)
self.input_col = input_col
self.output_col = output_col
self.gen_kwargs = self._gen_kwargs(
gen_kwargs,
model=self.ctx.model_id,
model=self.model_id,
temperature=0,
max_tokens=DEFAULT_MAX_NUM_TOKENS,
)
11 changes: 4 additions & 7 deletions src/instructlab/sdg/pipeline.py
@@ -16,7 +16,7 @@

# First Party
from instructlab.sdg.checkpointing import Checkpointer
from instructlab.sdg.utils import models, pandas
from instructlab.sdg.utils import pandas

# Local
from .blocks import llmblock
@@ -61,19 +61,16 @@ class PipelineContext: # pylint: disable=too-many-instance-attributes
DEFAULT_DATASET_NUM_PROCS = 8

client: OpenAI
model_family: str
model_id: str
num_instructions_to_generate: int
model_family: Optional[str] = None
model_id: Optional[str] = None
num_instructions_to_generate: Optional[int] = None
dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS
checkpoint_dir: Optional[str] = None
save_freq: Optional[int] = 1
max_num_tokens: Optional[int] = llmblock.DEFAULT_MAX_NUM_TOKENS
batch_size: int = DEFAULT_BATCH_SIZE
batch_num_workers: Optional[int] = None

def __post_init__(self):
self.model_family = models.get_model_family(self.model_family, self.model_id)

@property
def batching_enabled(self) -> bool:
"""Batching is enabled IFF the batch size is specified and the number of
12 changes: 12 additions & 0 deletions src/instructlab/sdg/pipelines/schema/v1.json
@@ -94,6 +94,12 @@
"type": "string"
}
},
"model_id": {
"type": "string"
},
"model_family": {
"type": "string"
},
"model_prompt": {
"type": "string"
},
@@ -177,6 +183,12 @@
"type": "string"
}
},
"model_id": {
"type": "string"
},
"model_family": {
"type": "string"
},
"model_prompt": {
"type": "string"
},
1 change: 1 addition & 0 deletions tests/functional/test_granular_api.py
@@ -91,6 +91,7 @@ def test_granular_api_end_to_end(self):
input_dir=preprocessed_dir,
output_dir=generated_dir,
pipeline=pipeline_dir,
model_id="mock",
num_cpus=1, # Test is faster running on a single CPU vs forking
batch_size=0, # Disable batch for tiny dataset and fastest test
)