expose max-num-tokens as a configurable field.
max-num-tokens is a nice way to run a shorter or longer SDG run.
Locally I have been modifying the pipeline YAML from 2048 to 512, which ends up just generating less data.
Exposing this through the CLI could allow power users to run different types of SDG runs!

Signed-off-by: Charlie Doern <[email protected]>
cdoern committed Nov 7, 2024
1 parent 2b9958c commit d4ebc02
Showing 4 changed files with 32 additions and 5 deletions.
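To make the mechanism concrete before the diffs: the commit threads a single max_num_tokens value from the pipeline context into the default gen_kwargs of every LLM block. Below is a minimal, self-contained sketch of that flow. ToyContext and merge_gen_kwargs are stand-ins for the real PipelineContext and LLMBlock._gen_kwargs, and the override behavior (YAML-supplied gen_kwargs beating the injected defaults) is an assumption based on the parameter name, not something this diff shows directly:

```python
# Toy sketch, not the real instructlab code.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyContext:
    max_num_tokens: Optional[int] = 4096  # same default the commit uses


def merge_gen_kwargs(gen_kwargs: dict, **defaults) -> dict:
    # Assumed semantics: defaults lose to anything the pipeline YAML set,
    # mirroring the _gen_kwargs(gen_kwargs, **defaults) parameter names.
    merged = {**defaults, **gen_kwargs}
    if "max_tokens" in merged:
        # YAML values can arrive as strings, hence the coercion.
        merged["max_tokens"] = int(merged["max_tokens"])
    return merged


ctx = ToyContext(max_num_tokens=512)
print(merge_gen_kwargs({}, max_tokens=ctx.max_num_tokens))
# {'max_tokens': 512} -> shorter completions, so the run generates less data
print(merge_gen_kwargs({"max_tokens": "2048"}, max_tokens=ctx.max_num_tokens))
# {'max_tokens': 2048} -> a YAML-level max_tokens would still win
```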
4 changes: 4 additions & 0 deletions src/instructlab/sdg/generate_data.py
@@ -184,6 +184,7 @@ def _context_init(
save_freq: int,
batch_num_workers: Optional[int],
batch_size: Optional[int],
max_num_tokens: Optional[int] = 4096,
):
extra_kwargs = {}
if batch_size is not None:
@@ -197,6 +198,7 @@ def _context_init(
num_instructions_to_generate=num_instructions_to_generate,
checkpoint_dir=checkpoint_dir,
save_freq=save_freq,
max_num_tokens=max_num_tokens,
**extra_kwargs,
)

@@ -282,6 +284,7 @@ def generate_data(
pipeline: Optional[str] = "simple",
batch_size: Optional[int] = None,
checkpoint_dir: Optional[str] = None,
max_num_tokens: Optional[int] = 4096,
) -> None:
"""Generate data for training and testing a model.
@@ -340,6 +343,7 @@ def generate_data(
1, # save_freq
batch_size=batch_size,
batch_num_workers=num_cpus,
max_num_tokens=max_num_tokens,
)

knowledge_pipe, freeform_skills_pipe, grounded_skills_pipe = _sdg_init(
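The file above is pure plumbing: generate_data grows a keyword argument and hands it to _context_init, which hands it to the context. A toy sketch of that pattern, with placeholder classes rather than the real ones, keeping the 4096 default at every layer so omitting the argument is a no-op:

```python
# Toy sketch of the threading pattern; names are placeholders.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyPipelineContext:
    save_freq: Optional[int] = 1
    max_num_tokens: Optional[int] = 4096


def context_init(save_freq: int,
                 max_num_tokens: Optional[int] = 4096) -> ToyPipelineContext:
    return ToyPipelineContext(save_freq=save_freq,
                              max_num_tokens=max_num_tokens)


def generate_data(max_num_tokens: Optional[int] = 4096) -> ToyPipelineContext:
    # The caller-facing default matches the one on the context itself,
    # so leaving the argument out anywhere in the chain behaves identically.
    return context_init(1, max_num_tokens=max_num_tokens)


print(generate_data(max_num_tokens=512).max_num_tokens)  # 512
print(generate_data().max_num_tokens)                    # 4096
```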
12 changes: 9 additions & 3 deletions src/instructlab/sdg/llmblock.py
@@ -62,6 +62,7 @@ def __init__(
ctx,
pipe,
block_name,
max_num_tokens,
config_path,
output_cols,
model_prompt=None,
@@ -82,7 +83,10 @@ def __init__(
self.parsing_pattern = parser_kwargs.get("parsing_pattern", None)
self.parser_cleanup_tags = parser_kwargs.get("parser_cleanup_tags", None)
self.gen_kwargs = self._gen_kwargs(
gen_kwargs, model=self.ctx.model_id, temperature=0, max_tokens=4096
gen_kwargs,
model=self.ctx.model_id,
temperature=0,
max_tokens=max_num_tokens,
)
# Whether the LLM server supports a list of input prompts
# and supports the n parameter to generate n outputs per input
@@ -150,10 +154,10 @@ def _gen_kwargs(self, gen_kwargs, **defaults):
and gen_kwargs["n"] == "scaled"
):
gen_kwargs["n"] = self.ctx.num_instructions_to_generate
if "max_tokens" in gen_kwargs:
gen_kwargs["max_tokens"] = int(gen_kwargs["max_tokens"])
if "temperature" in gen_kwargs:
gen_kwargs["temperature"] = float(gen_kwargs["temperature"])
if "max_tokens" in gen_kwargs:
gen_kwargs["max_tokens"] = int(gen_kwargs["max_tokens"])
return gen_kwargs

def _generate(self, samples) -> list:
@@ -259,6 +263,7 @@ def __init__(
ctx,
pipe,
block_name,
max_num_tokens,
config_paths,
output_cols,
selector_column_name,
@@ -271,6 +276,7 @@ def __init__(
ctx,
pipe,
block_name,
max_num_tokens,
config_paths[0][0],
output_cols,
model_prompt=model_prompt,
18 changes: 16 additions & 2 deletions src/instructlab/sdg/pipeline.py
@@ -48,6 +48,7 @@ class PipelineContext: # pylint: disable=too-many-instance-attributes
central executor pool.
dataset_num_procs: The number of processes to use when performing parallel
map operations on individual datasets.
max_num_tokens: The maximum number of tokens to generate per sample.
"""

# The default batch size of 8 has been determined as a good default for
@@ -65,6 +66,7 @@ class PipelineContext: # pylint: disable=too-many-instance-attributes
dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS
checkpoint_dir: Optional[str] = None
save_freq: Optional[int] = 1
max_num_tokens: Optional[int] = 4096
batch_size: int = DEFAULT_BATCH_SIZE
batch_num_workers: Optional[int] = None

@@ -191,12 +193,24 @@ def _generate_single(self, dataset) -> Dataset:
block_name = block_prop["name"]
block_type = _lookup_block_type(block_prop["type"])
block_config = block_prop["config"]
max_num_tokens = self.ctx.max_num_tokens
drop_columns = block_prop.get("drop_columns", [])
drop_duplicates_cols = block_prop.get("drop_duplicates", False)
block = block_type(self.ctx, self, block_name, **block_config)
if (
block_type == llmblock.LLMBlock
or block_type == llmblock.ConditionalLLMBlock
):
block = block_type(
self.ctx,
self,
block_name,
max_num_tokens,
**block_config,
)
else:
block = block_type(self.ctx, self, block_name, **block_config)
logger.info("Running block: %s", block_name)
logger.info(dataset)

# Execute the block and wrap errors with the block name/type
dataset = block.generate(dataset)
except Exception as err:
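The dispatch above is the one behavioral subtlety: only LLMBlock and ConditionalLLMBlock grow the extra constructor argument, so every other block type must keep its old call shape. A toy sketch of that dispatch, using the tuple-membership form the CI's pylint annotation (R1714) suggested in place of the two chained equality checks; the class bodies are placeholders:

```python
# Toy sketch of the type-based dispatch in pipeline.py; classes are stand-ins.
class LLMBlock:
    def __init__(self, ctx, pipe, block_name, max_num_tokens, **config):
        self.max_num_tokens = max_num_tokens


class ConditionalLLMBlock(LLMBlock):
    pass


class ToyFilterBlock:  # any non-LLM block keeps its original signature
    def __init__(self, ctx, pipe, block_name, **config):
        pass


def build_block(block_type, ctx, pipe, block_name, max_num_tokens, **config):
    # Only LLM-style blocks receive the new argument.
    if block_type in (LLMBlock, ConditionalLLMBlock):
        return block_type(ctx, pipe, block_name, max_num_tokens, **config)
    return block_type(ctx, pipe, block_name, **config)


block = build_block(LLMBlock, None, None, "gen_knowledge", 512)
print(block.max_num_tokens)  # 512
```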
3 changes: 3 additions & 0 deletions tests/test_llmblock.py
@@ -37,6 +37,7 @@ def test_model_prompt_empty_string(self, mock_load_config):
ctx=self.mock_ctx,
pipe=self.mock_pipe,
block_name="test_block",
max_num_tokens=2048,
config_path="",
output_cols=[],
model_prompt="",
@@ -57,6 +58,7 @@ def test_model_prompt_none(self, mock_load_config):
ctx=self.mock_ctx,
pipe=self.mock_pipe,
block_name="test_block",
max_num_tokens=2048,
config_path="",
output_cols=[],
model_prompt=None, # Or simply omit model_prompt as it defaults to None
@@ -76,6 +78,7 @@ def test_model_prompt_none(self, mock_load_config):
ctx=self.mock_ctx,
pipe=self.mock_pipe,
block_name="test_block",
max_num_tokens=2048,
config_path="",
output_cols=[],
model_prompt="FOO {prompt} BAR",
