diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py
index d7bfd9f2..8c1bda54 100644
--- a/src/instructlab/sdg/generate_data.py
+++ b/src/instructlab/sdg/generate_data.py
@@ -184,6 +184,7 @@ def _context_init(
     save_freq: int,
     batch_num_workers: Optional[int],
     batch_size: Optional[int],
+    max_num_tokens: Optional[int] = 4096,
 ):
     extra_kwargs = {}
     if batch_size is not None:
@@ -197,6 +198,7 @@ def _context_init(
         num_instructions_to_generate=num_instructions_to_generate,
         checkpoint_dir=checkpoint_dir,
         save_freq=save_freq,
+        max_num_tokens=max_num_tokens,
         **extra_kwargs,
     )

@@ -282,6 +284,7 @@ def generate_data(
     pipeline: Optional[str] = "simple",
     batch_size: Optional[int] = None,
     checkpoint_dir: Optional[str] = None,
+    max_num_tokens: Optional[int] = 4096,
 ) -> None:
     """Generate data for training and testing a model.

@@ -340,6 +343,7 @@ def generate_data(
         1,  # save_freq
         batch_size=batch_size,
         batch_num_workers=num_cpus,
+        max_num_tokens=max_num_tokens,
     )

     knowledge_pipe, freeform_skills_pipe, grounded_skills_pipe = _sdg_init(
diff --git a/src/instructlab/sdg/llmblock.py b/src/instructlab/sdg/llmblock.py
index 2ddd30c2..9d81984a 100644
--- a/src/instructlab/sdg/llmblock.py
+++ b/src/instructlab/sdg/llmblock.py
@@ -62,6 +62,7 @@ def __init__(
         ctx,
         pipe,
         block_name,
+        max_num_tokens,
         config_path,
         output_cols,
         model_prompt=None,
@@ -82,7 +83,10 @@ def __init__(
         self.parsing_pattern = parser_kwargs.get("parsing_pattern", None)
         self.parser_cleanup_tags = parser_kwargs.get("parser_cleanup_tags", None)
         self.gen_kwargs = self._gen_kwargs(
-            gen_kwargs, model=self.ctx.model_id, temperature=0, max_tokens=4096
+            gen_kwargs,
+            model=self.ctx.model_id,
+            temperature=0,
+            max_tokens=max_num_tokens,
         )
         # Whether the LLM server supports a list of input prompts
         # and supports the n parameter to generate n outputs per input
@@ -150,10 +154,10 @@ def _gen_kwargs(self, gen_kwargs, **defaults):
             and gen_kwargs["n"] == "scaled"
         ):
             gen_kwargs["n"] = self.ctx.num_instructions_to_generate
-        if "max_tokens" in gen_kwargs:
-            gen_kwargs["max_tokens"] = int(gen_kwargs["max_tokens"])
         if "temperature" in gen_kwargs:
             gen_kwargs["temperature"] = float(gen_kwargs["temperature"])
+        if "max_tokens" in gen_kwargs:
+            gen_kwargs["max_tokens"] = int(gen_kwargs["max_tokens"])
         return gen_kwargs

     def _generate(self, samples) -> list:
@@ -259,6 +263,7 @@ def __init__(
         ctx,
         pipe,
         block_name,
+        max_num_tokens,
         config_paths,
         output_cols,
         selector_column_name,
@@ -271,6 +276,7 @@ def __init__(
             ctx,
             pipe,
             block_name,
+            max_num_tokens,
             config_paths[0][0],
             output_cols,
             model_prompt=model_prompt,
diff --git a/src/instructlab/sdg/pipeline.py b/src/instructlab/sdg/pipeline.py
index f349e5d5..0ae13f54 100644
--- a/src/instructlab/sdg/pipeline.py
+++ b/src/instructlab/sdg/pipeline.py
@@ -48,6 +48,7 @@ class PipelineContext:  # pylint: disable=too-many-instance-attributes
             central executor pool.
         dataset_num_procs: The number of processes to use when performing parallel
             map operations on individual datasets.
+        max_num_tokens: the maximum number of tokens to generate per sample.
     """

     # The default batch size of 8 has been determined as a good default for
@@ -65,6 +66,7 @@ class PipelineContext:  # pylint: disable=too-many-instance-attributes
     dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS
     checkpoint_dir: Optional[str] = None
     save_freq: Optional[int] = 1
+    max_num_tokens: Optional[int] = 4096
     batch_size: int = DEFAULT_BATCH_SIZE
     batch_num_workers: Optional[int] = None

@@ -191,12 +193,24 @@ def _generate_single(self, dataset) -> Dataset:
                 block_name = block_prop["name"]
                 block_type = _lookup_block_type(block_prop["type"])
                 block_config = block_prop["config"]
+                max_num_tokens = self.ctx.max_num_tokens
                 drop_columns = block_prop.get("drop_columns", [])
                 drop_duplicates_cols = block_prop.get("drop_duplicates", False)
-                block = block_type(self.ctx, self, block_name, **block_config)
+                if (
+                    block_type == llmblock.LLMBlock
+                    or block_type == llmblock.ConditionalLLMBlock
+                ):
+                    block = block_type(
+                        self.ctx,
+                        self,
+                        block_name,
+                        max_num_tokens,
+                        **block_config,
+                    )
+                else:
+                    block = block_type(self.ctx, self, block_name, **block_config)
                 logger.info("Running block: %s", block_name)
                 logger.info(dataset)

-                # Execute the block and wrap errors with the block name/type
                 dataset = block.generate(dataset)
         except Exception as err:
diff --git a/tests/test_llmblock.py b/tests/test_llmblock.py
index f7835f32..c9756896 100644
--- a/tests/test_llmblock.py
+++ b/tests/test_llmblock.py
@@ -37,6 +37,7 @@ def test_model_prompt_empty_string(self, mock_load_config):
             ctx=self.mock_ctx,
             pipe=self.mock_pipe,
             block_name="test_block",
+            max_num_tokens=2048,
             config_path="",
             output_cols=[],
             model_prompt="",
@@ -57,6 +58,7 @@ def test_model_prompt_none(self, mock_load_config):
             ctx=self.mock_ctx,
             pipe=self.mock_pipe,
             block_name="test_block",
+            max_num_tokens=2048,
             config_path="",
             output_cols=[],
             model_prompt=None,  # Or simply omit model_prompt as it defaults to None
@@ -76,6 +78,7 @@ def test_model_prompt_none(self, mock_load_config):
             ctx=self.mock_ctx,
             pipe=self.mock_pipe,
             block_name="test_block",
+            max_num_tokens=2048,
             config_path="",
             output_cols=[],
             model_prompt="FOO {prompt} BAR",