Make everything but client optional on PipelineContext #518

Merged
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,15 @@
## Upcoming v0.8.x

### Features

#### LLMBlocks can now specify `model_family` or `model_id` in their config

Each `LLMBlock` in a `Pipeline` can now specify `model_family` or `model_id` in its YAML configuration to set the values used by that block, as opposed to setting them for the entire `Pipeline` in the `PipelineContext` object. This is useful when multiple `LLMBlock` entries exist in the same `Pipeline` and each one uses a different model.

#### Fewer required parameters for `PipelineContext`

The parameters `model_family`, `model_id`, and `num_instructions_to_generate` are no longer required in `PipelineContext` objects. If passed in, they are still used as before, but they can now be omitted when your `Pipeline` contains no `LLMBlock` entries or when your `LLMBlock` configs specify these values in the `Pipeline` YAML.
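
As an illustration, a context can now be built from just a client. This is a minimal sketch, not taken from the repository's docs; the server URL and API key are placeholders:

```python
# Minimal sketch: only the client is required on PipelineContext now.
# The base_url and api_key are placeholders for your own inference server.
from openai import OpenAI

from instructlab.sdg.pipeline import PipelineContext

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# model_family, model_id, and num_instructions_to_generate are omitted here;
# they are only needed when LLMBlocks in the Pipeline don't specify them.
ctx = PipelineContext(client=client)
```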

## v0.7.0

### Features
5 changes: 5 additions & 0 deletions docs/examples/multiple_llms/README.md
@@ -0,0 +1,5 @@
# Multiple LLMs in a single Pipeline

To create a `Pipeline` that uses multiple different LLMs for inference, each `LLMBlock` in the `Pipeline` YAML should specify a `model_family` (to indicate which set of prompts to use) and a `model_id` (the ID of that model on the deployed inference server). For more control over the prompting, a `model_prompt` may be given in place of `model_family` to specify the exact prompt to use instead of a generic one for that model family.

See `pipeline.yaml` in this directory for an example of a `Pipeline` that calls into three separate LLMs; a sketch of loading and running it follows below.
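
The sketch below is one way such a pipeline might be loaded and run; it assumes the `Pipeline.from_file` and `generate` APIs, and the client settings and the `document` input column are placeholders for whatever the referenced block configs actually expect:

```python
# Hedged sketch: paths, client settings, and the input column are placeholders.
from datasets import Dataset
from openai import OpenAI

from instructlab.sdg.pipeline import Pipeline, PipelineContext

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# No model_family or model_id on the context -- each LLMBlock in
# pipeline.yaml supplies its own.
ctx = PipelineContext(client=client)

pipeline = Pipeline.from_file(ctx, "docs/examples/multiple_llms/pipeline.yaml")
seed_data = Dataset.from_list([{"document": "Some seed text to generate from."}])
result = pipeline.generate(seed_data)
print(result)
```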
34 changes: 34 additions & 0 deletions docs/examples/multiple_llms/pipeline.yaml
@@ -0,0 +1,34 @@
version: "1.0"
blocks:
- name: model_one
type: LLMBlock
config:
model_family: mixtral
model_id: Mixtral-8x7B-Instruct-v0.1
config_path: model_one_config.yaml
output_cols:
- column_one
gen_kwargs:
max_tokens: 2048

- name: model_two
type: LLMBlock
config:
model_family: granite
model_id: granite-7b-lab
config_path: model_two_config.yaml
output_cols:
- column_two
gen_kwargs:
max_tokens: 512

- name: model_three
type: LLMBlock
config:
model_id: granite-7b-lab
model_prompt: <s> [INST] {prompt} [/INST]
config_path: model_three_config.yaml
output_cols:
- column_three
gen_kwargs:
max_tokens: 5
1 change: 1 addition & 0 deletions scripts/validate_pipelines.py
@@ -30,6 +30,7 @@ def main():
        schema = json.load(file)

    yaml_files = glob.glob("src/instructlab/sdg/pipelines/**/*.yaml", recursive=True)
    yaml_files.extend(glob.glob("docs/examples/**/pipeline.yaml", recursive=True))
    all_valid = True
    for yaml_file in yaml_files:
        print("=======================================================")
57 changes: 53 additions & 4 deletions src/instructlab/sdg/blocks/llmblock.py
@@ -15,6 +15,7 @@
# Import prompts to register default chat templates
from .. import prompts as default_prompts # pylint: disable=unused-import
from ..registry import BlockRegistry, PromptRegistry
from ..utils import models
from .block import Block, BlockConfigParserError

logger = logging.getLogger(__name__)
@@ -62,6 +63,33 @@ def template_from_struct_and_config(struct, config):
    return PromptRegistry.template_from_string(struct.format(**filtered_config))


def _resolve_model_id(model_id, ctx_model_id, block):
    # If a model id was passed in the PipelineContext, use that
    if ctx_model_id:
        return ctx_model_id

    # If we have no model id at all, raise an error
    if not model_id:
        raise BlockConfigParserError(
            f"{type(block).__name__} {block.block_name} requires a model_id but none was specified in the block config nor passed via the PipelineContext"
        )

    # Otherwise fallback to the model_id specified in the block config
    return model_id


def _resolve_model_family(model_family, ctx_model_family):
    # If a model family was passed in the PipelineContext, use that
    if ctx_model_family:
        return ctx_model_family

    # Otherwise fallback to the model_family specified in the block config
    # This could be None, but that's ok for model_family because when the
    # actual prompt gets selected it falls back to merlinite if nothing
    # else was given.
    return model_family


# This is part of the public API.
@BlockRegistry.register("LLMBlock")
# pylint: disable=dangerous-default-value
@@ -74,6 +102,8 @@ def __init__(
        block_name,
        config_path,
        output_cols,
        model_id=None,
        model_family=None,
        model_prompt=None,
        gen_kwargs={},
        parser_kwargs={},
@@ -87,6 +117,11 @@
        self.prompt_template = template_from_struct_and_config(
            self.prompt_struct, self.block_config
        )
        self.model_id = _resolve_model_id(model_id, self.ctx.model_id, self)
        self.model_family = models.get_model_family(
            _resolve_model_family(model_family, self.ctx.model_family),
            self.model_id,
        )
        self.model_prompt = model_prompt
        self.output_cols = output_cols
        self.batch_params = batch_kwargs
@@ -104,14 +139,14 @@ def __init__(
        self.gen_kwargs = self._gen_kwargs(
            max_num_token_override,
            gen_kwargs,
            model=self.ctx.model_id,
            model=self.model_id,
            temperature=0,
            max_tokens=DEFAULT_MAX_NUM_TOKENS,
        )
        # Whether the LLM server supports a list of input prompts
        # and supports the n parameter to generate n outputs per input
        self.server_supports_batched = server_supports_batched(
            self.ctx.client, self.ctx.model_id
            self.ctx.client, self.model_id
        )

    def _parse(self, generated_string) -> dict:
@@ -160,7 +195,7 @@ def _format_prompt(self, sample: Dict) -> str:

        model_prompt = None
        if self.model_prompt is None:
            model_prompt = PromptRegistry.get_template(self.ctx.model_family)
            model_prompt = PromptRegistry.get_template(self.model_family)
        elif self.model_prompt:
            model_prompt = PromptRegistry.template_from_string(self.model_prompt)
        else:
@@ -183,6 +218,10 @@ def _gen_kwargs(self, max_num_token_override, gen_kwargs, **defaults):
            and isinstance(gen_kwargs["n"], str)
            and gen_kwargs["n"] == "scaled"
        ):
            if not self.ctx.num_instructions_to_generate:
                raise BlockConfigParserError(
                    f"""LLMBlock {self.block_name} has a gen_kwargs["n"] value of "scaled" but num_instructions_to_generate was not set in the PipelineContext"""
                )
            gen_kwargs["n"] = self.ctx.num_instructions_to_generate
        if "temperature" in gen_kwargs:
            gen_kwargs["temperature"] = float(gen_kwargs["temperature"])
@@ -285,6 +324,8 @@ def __init__(
        config_paths,
        output_cols,
        selector_column_name,
        model_id=None,
        model_family=None,
        model_prompt=None,
        gen_kwargs={},
        parser_kwargs={},
@@ -305,6 +346,8 @@
            block_name,
            config_paths[0][0],
            output_cols,
            model_id=model_id,
            model_family=model_family,
            model_prompt=model_prompt,
            gen_kwargs=gen_kwargs,
            parser_kwargs=parser_kwargs,
@@ -360,6 +403,8 @@ def __init__(
        block_name,
        config_path,
        output_cols,
        model_id=None,
        model_family=None,
        model_prompt=None,
        gen_kwargs={},
        parser_kwargs={},
@@ -371,6 +416,8 @@
            block_name,
            config_path,
            output_cols,
            model_id=model_id,
            model_family=model_family,
            model_prompt=model_prompt,
            gen_kwargs=gen_kwargs,
            parser_kwargs=parser_kwargs,
@@ -466,14 +513,16 @@ def __init__(
        block_name,
        input_col,
        output_col,
        model_id=None,
        gen_kwargs={},
    ) -> None:
        super().__init__(ctx, pipe, block_name)
        self.model_id = _resolve_model_id(model_id, self.ctx.model_id, self)
        self.input_col = input_col
        self.output_col = output_col
        self.gen_kwargs = self._gen_kwargs(
            gen_kwargs,
            model=self.ctx.model_id,
            model=self.model_id,
            temperature=0,
            max_tokens=DEFAULT_MAX_NUM_TOKENS,
        )
11 changes: 4 additions & 7 deletions src/instructlab/sdg/pipeline.py
@@ -16,7 +16,7 @@

# First Party
from instructlab.sdg.checkpointing import Checkpointer
from instructlab.sdg.utils import models, pandas
from instructlab.sdg.utils import pandas

# Local
from .blocks import llmblock
@@ -61,19 +61,16 @@ class PipelineContext:  # pylint: disable=too-many-instance-attributes
    DEFAULT_DATASET_NUM_PROCS = 8

    client: OpenAI
    model_family: str
    model_id: str
    num_instructions_to_generate: int
    model_family: Optional[str] = None
    model_id: Optional[str] = None
    num_instructions_to_generate: Optional[int] = None
    dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS
    checkpoint_dir: Optional[str] = None
    save_freq: Optional[int] = 1
    max_num_tokens: Optional[int] = llmblock.DEFAULT_MAX_NUM_TOKENS
    batch_size: int = DEFAULT_BATCH_SIZE
    batch_num_workers: Optional[int] = None

    def __post_init__(self):
        self.model_family = models.get_model_family(self.model_family, self.model_id)

    @property
    def batching_enabled(self) -> bool:
        """Batching is enabled IFF the batch size is specified and the number of
12 changes: 12 additions & 0 deletions src/instructlab/sdg/pipelines/schema/v1.json
@@ -94,6 +94,12 @@
"type": "string"
}
},
"model_id": {
"type": "string"
},
"model_family": {
"type": "string"
},
"model_prompt": {
"type": "string"
},
@@ -177,6 +183,12 @@
"type": "string"
}
},
"model_id": {
"type": "string"
},
"model_family": {
"type": "string"
},
"model_prompt": {
"type": "string"
},
1 change: 1 addition & 0 deletions tests/functional/test_granular_api.py
@@ -91,6 +91,7 @@ def test_granular_api_end_to_end(self):
            input_dir=preprocessed_dir,
            output_dir=generated_dir,
            pipeline=pipeline_dir,
            model_id="mock",
            num_cpus=1,  # Test is faster running on a single CPU vs forking
            batch_size=0,  # Disable batch for tiny dataset and fastest test
        )