Implement LLMMessagesBlock
This was previously just a commented-out stub, ported over from the
research prototypes. This commit uncomments it, ports it to work within
the current structure of InstructLab, and adds a few tests to verify
the logic within the block.

Fixes #414

Signed-off-by: Ben Browning <[email protected]>
bbrowning committed Jan 6, 2025
1 parent 02ccaef commit 8576050
Showing 2 changed files with 107 additions and 46 deletions.
84 changes: 50 additions & 34 deletions src/instructlab/sdg/blocks/llmblock.py
@@ -457,48 +457,64 @@ def __init__(
 
 # This is part of the public API.
 @BlockRegistry.register("LLMMessagesBlock")
-class LLMMessagesBlock(LLMBlock):
+class LLMMessagesBlock(Block):
     def __init__(
         self,
         ctx,
         pipe,
         block_name,
-        config_path,
-        output_cols,
-        model_prompt=None,
+        input_col,
+        output_col,
         gen_kwargs={},
-        parser_kwargs={},
-        batch_kwargs={},
     ) -> None:
-        super().__init__(
-            ctx,
-            pipe,
-            block_name,
-            config_path,
-            output_cols,
-            model_prompt=model_prompt,
-            gen_kwargs=gen_kwargs,
-            parser_kwargs=parser_kwargs,
-            batch_kwargs=batch_kwargs,
+        super().__init__(ctx, pipe, block_name)
+        self.input_col = input_col
+        self.output_col = output_col
+        self.gen_kwargs = self._gen_kwargs(
+            gen_kwargs,
+            model=self.ctx.model_id,
+            temperature=0,
+            max_tokens=DEFAULT_MAX_NUM_TOKENS,
         )
 
-    # def _generate(self, samples) -> list:
-    #     generate_args = {**self.defaults, **gen_kwargs}
-
-    #     if "n" in generate_args and generate_args.get("temperature", 0) <= 0:
-    #         generate_args["temperature"] = 0.7
-    #         logger.warning(
-    #             "Temperature should be greater than 0 for n > 1, setting temperature to 0.7"
-    #         )
+    def _gen_kwargs(self, gen_kwargs, **defaults):
+        gen_kwargs = {**defaults, **gen_kwargs}
+        if "temperature" in gen_kwargs:
+            gen_kwargs["temperature"] = float(gen_kwargs["temperature"])
+        if (
+            "n" in gen_kwargs
+            and gen_kwargs["n"] > 1
+            and gen_kwargs.get("temperature", 0) <= 0
+        ):
+            gen_kwargs["temperature"] = 0.7
+            logger.warning(
+                "Temperature should be greater than 0 for n > 1, setting temperature to 0.7"
+            )
+        return gen_kwargs
 
-    # messages = samples[self.input_col]
+    def _generate(self, samples) -> list:
+        messages = samples[self.input_col]
+        logger.debug("STARTING GENERATION FOR LLMMessagesBlock")
+        logger.debug(f"Generation arguments: {self.gen_kwargs}")
+        results = []
+        progress_bar = tqdm(
+            range(len(samples)), desc=f"{self.block_name} Chat Completion Generation"
+        )
+        n = self.gen_kwargs.get("n", 1)
+        for message in messages:
+            logger.debug(f"CREATING CHAT COMPLETION FOR MESSAGE: {message}")
+            responses = self.ctx.client.chat.completions.create(
+                messages=message, **self.gen_kwargs
+            )
+            if n > 1:
+                results.append([choice.message.content for choice in responses.choices])
+            else:
+                results.append(responses.choices[0].message.content)
+            progress_bar.update(n)
+        return results
 
-    # results = []
-    # n = gen_kwargs.get("n", 1)
-    # for message in messages:
-    #     responses = self.client.chat.completions.create(messages=message, **generate_args)
-    #     if n > 1:
-    #         results.append([choice.message.content for choice in responses.choices])
-    #     else:
-    #         results.append(responses.choices[0].message.content)
-    # return results
+    def generate(self, samples: Dataset) -> Dataset:
+        outputs = self._generate(samples)
+        logger.debug("Generated outputs: %s", outputs)
+        samples = samples.add_column(self.output_col, outputs)
+        return samples
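
For context, the temperature guard in the new _gen_kwargs can be illustrated in isolation. The following is a minimal sketch (not part of the commit) that copies the merge-and-validate logic above into a free function; merge_gen_kwargs is an illustrative name:

import logging

logger = logging.getLogger(__name__)


def merge_gen_kwargs(gen_kwargs, **defaults):
    # Caller-supplied kwargs override the block's defaults.
    merged = {**defaults, **gen_kwargs}
    if "temperature" in merged:
        merged["temperature"] = float(merged["temperature"])
    # Requesting n > 1 greedy (temperature <= 0) completions would yield
    # n identical results, so bump the temperature, as the block does.
    if "n" in merged and merged["n"] > 1 and merged.get("temperature", 0) <= 0:
        merged["temperature"] = 0.7
        logger.warning(
            "Temperature should be greater than 0 for n > 1, setting temperature to 0.7"
        )
    return merged


# The caller asks for temperature=0, but n=5 forces it up to 0.7,
# while n=1 leaves the greedy temperature alone.
assert merge_gen_kwargs({"n": 5, "temperature": 0}, model="test_model")["temperature"] == 0.7
assert merge_gen_kwargs({"n": 1, "temperature": 0}, model="test_model")["temperature"] == 0
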
69 changes: 57 additions & 12 deletions tests/test_llmblock.py
@@ -336,28 +336,73 @@ def test_constructor_works(self, mock_load_config):
         assert block is not None
 
 
-@patch("src.instructlab.sdg.blocks.block.Block._load_config")
 class TestLLMMessagesBlock(unittest.TestCase):
     def setUp(self):
         self.mock_ctx = MagicMock()
         self.mock_ctx.model_family = "mixtral"
         self.mock_ctx.model_id = "test_model"
         self.mock_pipe = MagicMock()
-        self.config_return_value = {
-            "system": "{{fruit}}",
-            "introduction": "introduction",
-            "principles": "principles",
-            "examples": "examples",
-            "generation": "generation",
-        }
+        self.mock_client = MagicMock()
+        self.mock_ctx.client = self.mock_client
 
-    def test_constructor_works(self, mock_load_config):
-        mock_load_config.return_value = self.config_return_value
+    def test_constructor_works(self):
         block = LLMMessagesBlock(
             ctx=self.mock_ctx,
             pipe=self.mock_pipe,
             block_name="gen_knowledge",
-            config_path="",
-            output_cols=[],
+            input_col="messages",
+            output_col="output",
         )
         assert block is not None
+
+    def test_temperature_validation(self):
+        block = LLMMessagesBlock(
+            ctx=self.mock_ctx,
+            pipe=self.mock_pipe,
+            block_name="gen_knowledge",
+            input_col="messages",
+            output_col="output",
+            gen_kwargs={"n": 5, "temperature": 0},
+        )
+        assert block.gen_kwargs["temperature"] != 0
+
+        block = LLMMessagesBlock(
+            ctx=self.mock_ctx,
+            pipe=self.mock_pipe,
+            block_name="gen_knowledge",
+            input_col="messages",
+            output_col="output",
+            gen_kwargs={"n": 1, "temperature": 0},
+        )
+        assert block.gen_kwargs["temperature"] == 0
+
+    def test_calls_chat_completion_api(self):
+        # Mock the OpenAI client so we don't actually hit a server here
+        mock_choice = MagicMock()
+        mock_choice.message = MagicMock()
+        mock_choice.message.content = "generated response"
+        mock_completion_resp = MagicMock()
+        mock_completion_resp.choices = [mock_choice]
+        mock_completion = MagicMock()
+        mock_completion.create = MagicMock()
+        mock_completion.create.return_value = mock_completion_resp
+        mock_chat = MagicMock()
+        mock_chat.completions = mock_completion
+        self.mock_client.chat = mock_chat
+        block = LLMMessagesBlock(
+            ctx=self.mock_ctx,
+            pipe=self.mock_pipe,
+            block_name="gen_knowledge",
+            input_col="messages",
+            output_col="output",
+            gen_kwargs={"n": 1, "temperature": 0},
+        )
+        samples = Dataset.from_dict(
+            {"messages": ["my message"]},
+            features=Features({"messages": Value("string")}),
+        )
+        output = block.generate(samples)
+        assert len(output) == 1
+        assert "output" in output.column_names
+        mock_completion.create.assert_called()
+        assert mock_completion.create.call_args.kwargs["messages"] == "my message"
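
Beyond the mocked test above, where the messages column holds a bare string, a real input column would typically hold a chat-format list of role/content dicts, since each row's value is passed directly as the messages= argument to the chat completions API. As a hedged usage sketch, assuming an OpenAI-compatible server and a context object exposing client and model_id (the SimpleNamespace wiring and endpoint values below are illustrative, not how InstructLab pipelines actually construct their context):

from types import SimpleNamespace

from datasets import Dataset
from openai import OpenAI

from instructlab.sdg.blocks.llmblock import LLMMessagesBlock

# Illustrative stand-in for the pipeline context; real pipelines build
# this object themselves. The base_url and model name are placeholders.
ctx = SimpleNamespace(
    client=OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY"),
    model_id="my-served-model",
)

block = LLMMessagesBlock(
    ctx=ctx,
    pipe=None,  # unused in this sketch; a real pipeline passes one in
    block_name="chat_gen",
    input_col="messages",
    output_col="output",
    gen_kwargs={"n": 1, "temperature": 0.7, "max_tokens": 512},
)

# One row whose "messages" value is a chat-format message list.
samples = Dataset.from_dict(
    {"messages": [[{"role": "user", "content": "Say hello."}]]}
)
print(block.generate(samples)["output"][0])
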
