From 85760506bc47e861b2f0376748da424f8ea9d6f1 Mon Sep 17 00:00:00 2001
From: Ben Browning
Date: Mon, 6 Jan 2025 14:37:58 -0500
Subject: [PATCH] Implement LLMMessagesBlock

This was previously just a commented-out stub, ported over from the
research prototypes. This change uncomments it, ports it to work within
the current structure of InstructLab, and adds a few tests to verify the
logic within the block.

Fixes #414

Signed-off-by: Ben Browning
---
 src/instructlab/sdg/blocks/llmblock.py | 84 +++++++++++++++-----------
 tests/test_llmblock.py                 | 69 +++++++++++++++++----
 2 files changed, 107 insertions(+), 46 deletions(-)

diff --git a/src/instructlab/sdg/blocks/llmblock.py b/src/instructlab/sdg/blocks/llmblock.py
index 89d9a27e..b940954d 100644
--- a/src/instructlab/sdg/blocks/llmblock.py
+++ b/src/instructlab/sdg/blocks/llmblock.py
@@ -457,48 +457,64 @@ def __init__(
 
 # This is part of the public API.
 @BlockRegistry.register("LLMMessagesBlock")
-class LLMMessagesBlock(LLMBlock):
+class LLMMessagesBlock(Block):
     def __init__(
         self,
         ctx,
         pipe,
         block_name,
-        config_path,
-        output_cols,
-        model_prompt=None,
+        input_col,
+        output_col,
         gen_kwargs={},
-        parser_kwargs={},
-        batch_kwargs={},
     ) -> None:
-        super().__init__(
-            ctx,
-            pipe,
-            block_name,
-            config_path,
-            output_cols,
-            model_prompt=model_prompt,
-            gen_kwargs=gen_kwargs,
-            parser_kwargs=parser_kwargs,
-            batch_kwargs=batch_kwargs,
+        super().__init__(ctx, pipe, block_name)
+        self.input_col = input_col
+        self.output_col = output_col
+        self.gen_kwargs = self._gen_kwargs(
+            gen_kwargs,
+            model=self.ctx.model_id,
+            temperature=0,
+            max_tokens=DEFAULT_MAX_NUM_TOKENS,
         )
 
-    # def _generate(self, samples) -> list:
-    #     generate_args = {**self.defaults, **gen_kwargs}
-
-    #     if "n" in generate_args and generate_args.get("temperature", 0) <= 0:
-    #         generate_args["temperature"] = 0.7
-    #         logger.warning(
-    #             "Temperature should be greater than 0 for n > 1, setting temperature to 0.7"
-    #         )
+    def _gen_kwargs(self, gen_kwargs, **defaults):
+        gen_kwargs = {**defaults, **gen_kwargs}
+        if "temperature" in gen_kwargs:
+            gen_kwargs["temperature"] = float(gen_kwargs["temperature"])
+        if (
+            "n" in gen_kwargs
+            and gen_kwargs["n"] > 1
+            and gen_kwargs.get("temperature", 0) <= 0
+        ):
+            gen_kwargs["temperature"] = 0.7
+            logger.warning(
+                "Temperature should be greater than 0 for n > 1, setting temperature to 0.7"
+            )
+        return gen_kwargs
 
-    #     messages = samples[self.input_col]
+    def _generate(self, samples) -> list:
+        messages = samples[self.input_col]
+        logger.debug("STARTING GENERATION FOR LLMMessagesBlock")
+        logger.debug(f"Generation arguments: {self.gen_kwargs}")
+        results = []
+        progress_bar = tqdm(
+            range(len(samples)), desc=f"{self.block_name} Chat Completion Generation"
+        )
+        n = self.gen_kwargs.get("n", 1)
+        for message in messages:
+            logger.debug(f"CREATING CHAT COMPLETION FOR MESSAGE: {message}")
+            responses = self.ctx.client.chat.completions.create(
+                messages=message, **self.gen_kwargs
+            )
+            if n > 1:
+                results.append([choice.message.content for choice in responses.choices])
+            else:
+                results.append(responses.choices[0].message.content)
+            progress_bar.update(n)
+        return results
 
-    #     results = []
-    #     n = gen_kwargs.get("n", 1)
-    #     for message in messages:
-    #         responses = self.client.chat.completions.create(messages=message, **generate_args)
-    #         if n > 1:
-    #             results.append([choice.message.content for choice in responses.choices])
-    #         else:
-    #             results.append(responses.choices[0].message.content)
-    #     return results
+    def generate(self, samples: Dataset) -> Dataset:
+        outputs = self._generate(samples)
+        logger.debug("Generated outputs: %s", outputs)
+        samples = samples.add_column(self.output_col, outputs)
+        return samples
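For review convenience, here is the temperature-bumping rule from the new _gen_kwargs method as a standalone, runnable sketch. It is illustrative only and not part of the patch; resolve_gen_kwargs is a hypothetical free-function copy of the method body above, so it can be exercised without the Block machinery.

# Illustrative sketch (not part of the patch): a free-function copy of the
# _gen_kwargs logic above. resolve_gen_kwargs is a hypothetical name.
import logging

logging.basicConfig()
logger = logging.getLogger(__name__)


def resolve_gen_kwargs(gen_kwargs, **defaults):
    # Caller-supplied kwargs override the block-level defaults.
    gen_kwargs = {**defaults, **gen_kwargs}
    if "temperature" in gen_kwargs:
        gen_kwargs["temperature"] = float(gen_kwargs["temperature"])
    # n > 1 with greedy sampling (temperature <= 0) would yield n identical
    # completions, so bump the temperature and warn, as the block does.
    if (
        "n" in gen_kwargs
        and gen_kwargs["n"] > 1
        and gen_kwargs.get("temperature", 0) <= 0
    ):
        gen_kwargs["temperature"] = 0.7
        logger.warning(
            "Temperature should be greater than 0 for n > 1, setting temperature to 0.7"
        )
    return gen_kwargs


print(resolve_gen_kwargs({"n": 5, "temperature": 0}, model="test_model"))
# {'model': 'test_model', 'n': 5, 'temperature': 0.7}
print(resolve_gen_kwargs({"n": 1, "temperature": 0}, model="test_model"))
# {'model': 'test_model', 'n': 1, 'temperature': 0.0}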
diff --git a/tests/test_llmblock.py b/tests/test_llmblock.py
index f9d78177..2205ab5e 100644
--- a/tests/test_llmblock.py
+++ b/tests/test_llmblock.py
@@ -336,28 +336,73 @@ def test_constructor_works(self, mock_load_config):
         assert block is not None
 
 
-@patch("src.instructlab.sdg.blocks.block.Block._load_config")
 class TestLLMMessagesBlock(unittest.TestCase):
     def setUp(self):
         self.mock_ctx = MagicMock()
         self.mock_ctx.model_family = "mixtral"
         self.mock_ctx.model_id = "test_model"
         self.mock_pipe = MagicMock()
-        self.config_return_value = {
-            "system": "{{fruit}}",
-            "introduction": "introduction",
-            "principles": "principles",
-            "examples": "examples",
-            "generation": "generation",
-        }
+        self.mock_client = MagicMock()
+        self.mock_ctx.client = self.mock_client
 
-    def test_constructor_works(self, mock_load_config):
-        mock_load_config.return_value = self.config_return_value
+    def test_constructor_works(self):
         block = LLMMessagesBlock(
             ctx=self.mock_ctx,
             pipe=self.mock_pipe,
             block_name="gen_knowledge",
-            config_path="",
-            output_cols=[],
+            input_col="messages",
+            output_col="output",
         )
         assert block is not None
+
+    def test_temperature_validation(self):
+        block = LLMMessagesBlock(
+            ctx=self.mock_ctx,
+            pipe=self.mock_pipe,
+            block_name="gen_knowledge",
+            input_col="messages",
+            output_col="output",
+            gen_kwargs={"n": 5, "temperature": 0},
+        )
+        assert block.gen_kwargs["temperature"] != 0
+
+        block = LLMMessagesBlock(
+            ctx=self.mock_ctx,
+            pipe=self.mock_pipe,
+            block_name="gen_knowledge",
+            input_col="messages",
+            output_col="output",
+            gen_kwargs={"n": 1, "temperature": 0},
+        )
+        assert block.gen_kwargs["temperature"] == 0
+
+    def test_calls_chat_completion_api(self):
+        # Mock the OpenAI client so we don't actually hit a server here
+        mock_choice = MagicMock()
+        mock_choice.message = MagicMock()
+        mock_choice.message.content = "generated response"
+        mock_completion_resp = MagicMock()
+        mock_completion_resp.choices = [mock_choice]
+        mock_completion = MagicMock()
+        mock_completion.create = MagicMock()
+        mock_completion.create.return_value = mock_completion_resp
+        mock_chat = MagicMock()
+        mock_chat.completions = mock_completion
+        self.mock_client.chat = mock_chat
+        block = LLMMessagesBlock(
+            ctx=self.mock_ctx,
+            pipe=self.mock_pipe,
+            block_name="gen_knowledge",
+            input_col="messages",
+            output_col="output",
+            gen_kwargs={"n": 1, "temperature": 0},
+        )
+        samples = Dataset.from_dict(
+            {"messages": ["my message"]},
+            features=Features({"messages": Value("string")}),
+        )
+        output = block.generate(samples)
+        assert len(output) == 1
+        assert "output" in output.column_names
+        mock_completion.create.assert_called()
+        assert mock_completion.create.call_args.kwargs["messages"] == "my message"
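Beyond the unit tests, here is a hypothetical end-to-end usage sketch of the new block, mirroring the mocking approach in test_calls_chat_completion_api. It is not part of the patch: the MagicMock context, the block and column names, and the import path (inferred from the file layout in the diff) are assumptions. One caveat worth noting for real use: each row of the input column is passed straight through as messages= to client.chat.completions.create, so against a real OpenAI client it should be an OpenAI-style chat message list; the unit test's bare string only works because the client there is a mock.

# Hypothetical usage sketch (not part of the patch). Import path assumed from
# the file layout in the diff above.
from unittest.mock import MagicMock

from datasets import Dataset

from instructlab.sdg.blocks.llmblock import LLMMessagesBlock

# Stub a pipeline context whose .client quacks like an OpenAI client.
ctx = MagicMock()
ctx.model_id = "test_model"
choice = MagicMock()
choice.message.content = "generated response"
ctx.client.chat.completions.create.return_value = MagicMock(choices=[choice])

block = LLMMessagesBlock(
    ctx=ctx,
    pipe=MagicMock(),  # stand-in for the enclosing Pipeline
    block_name="gen_facts",
    input_col="messages",
    output_col="output",
    gen_kwargs={"temperature": 0.7},
)

# Each row holds an OpenAI-style chat message list, as a real client expects.
samples = Dataset.from_dict(
    {"messages": [[{"role": "user", "content": "Write one fact about Python."}]]}
)
out = block.generate(samples)
print(out["output"])  # ['generated response']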