Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

📚 Adding Knowledge blocks #47

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions scripts/test_knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@

samples = [
{
"question_1": "what is the location of the tubal tonsils?",
"response_1": "The location of the tubal tonsils is the roof of the pharynx.",
"question_2": "How long does the adenoid grow?",
"icl_query_1": "what is the location of the tubal tonsils?",
"icl_response_1": "The location of the tubal tonsils is the roof of the pharynx.",
"icl_query_2": "How long does the adenoid grow?",
"task_description": "Teaching about human anatomy, specifically tonsils",
"response_2": "The adenoid grows until the age of 5, starts to shrink at the age of 7 and becomes small in adulthood.",
"question_3": "What is the immune systems first line of defense against ingested or inhaled foreign pathogens?",
"response_3": "The tonsils are the immune systems first line of defense.",
"icl_response_2": "The adenoid grows until the age of 5, starts to shrink at the age of 7 and becomes small in adulthood.",
"icl_query_3": "What is the immune systems first line of defense against ingested or inhaled foreign pathogens?",
"icl_response_3": "The tonsils are the immune systems first line of defense.",
"document": "The **tonsils** are a set of lymphoid organs facing into the aerodigestive tract, which is known as Waldeyer's tonsillar ring and consists of the adenoid tonsil or pharyngeal tonsil, two tubal tonsils, two palatine tonsils, and the lingual tonsils. These organs play an important role in the immune system. When used unqualified, the term most commonly refers specifically to the palatine tonsils, which are two lymphoid organs situated at either side of the back of the human throat. The palatine tonsils and the adenoid tonsil are organs consisting of lymphoepithelial tissue located near the oropharynx and nasopharynx parts of the throat",
"domain": "textbook",
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,31 +48,33 @@ examples: |

For this {domain} domain here are some sample questions:
[Start of Question]
{question_1}
{icl_query_1}
[End of Question]
[Start of Response]
{response_1}
{icl_response_1}
[End of Response]

[Start of Question]
{question_2}
{icl_query_2}
[End of Question]
[Start of Response]
{response_2}
{icl_response_2}
[End of Response]

[Start of Question]
{question_3}
{icl_query_3}
[End of Question]
[Start of Response]
{response_3}
{icl_response_3}
[End of Response]

generation: |
Now generate the question and answer pairs, remember to follow the principles mentioned above and use the same format as the examples. Remember to use the same style and format as the example above.

Here is the document:
{document}

generation: |
Now generate the question and answer pairs, remember to follow the principles mentioned above and use the same format as the examples. Remember to use the same style and format as the example above. Return each question between [Start of Question] and [End of Question] tags and answer between [Start of Response] and [End of Response] tags.
Return each question between [Start of Question] and [End of Question] tags and answer between [Start of Response] and [End of Response] tags.

start_tags: ["[Start of Question]", "[Start of Response]"]
end_tags: ["[End of Question]", "[End of Response]"]
94 changes: 91 additions & 3 deletions src/instructlab/sdg/llmblock.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# Standard
from typing import Any, Dict
import re

# Third Party
Expand Down Expand Up @@ -56,9 +57,8 @@
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
all_matches = re.findall(pattern, generated_string, re.DOTALL)
matches[output_col] = (
[match.strip() for match in all_matches] if all_matches else None
[match.strip() for match in all_matches] if all_matches else []
)

return matches

def _generate(self, samples, **gen_kwargs) -> list:
Expand Down Expand Up @@ -86,7 +86,7 @@
if (num_samples is not None) and ("num_samples" not in samples.column_names):
samples = samples.add_column("num_samples", [num_samples] * len(samples))

# validate the each sample
# validate each sample
for sample in samples:
if not self._validate(self.prompt_template, sample):
return None
Expand All @@ -107,3 +107,91 @@
new_data.append({**sample, **dict(zip(parsed_outputs.keys(), values))})

return Dataset.from_list(new_data)


class ConditionalLLMBlock(LLMBlock):
    """LLM block whose prompt template can be chosen per sample.

    ``prompt_template`` is either a single formatted template string (when
    exactly one config is given and its key is ``"All"``) or a dict mapping
    a config key to its formatted template. In the dict case the value found
    in each sample's ``selector_column_name`` column picks the template.

    NOTE(review): ``prompt_struct``, ``block_config``, ``_load_config``,
    ``model_prompt``, ``client`` and ``defaults`` are not defined in this
    class; presumably they are provided by ``LLMBlock`` -- confirm.
    """

    # Matches one "## Logical Section <n>: ..." chunk, lazily, up to (but not
    # including) the next section header or the end of the text.
    _LOGICAL_SECTION_RE = re.compile(
        r"## Logical Section \d+: (.*?)(?=## Logical Section \d+:|$)", re.DOTALL
    )

    def __init__(
        self,
        block_name,
        config_paths,
        client,
        model_id,
        output_cols,
        selector_column_name,
        parser_name,
        model_prompt="{prompt}",
        **batch_kwargs,
    ) -> None:
        """Build the block and its prompt template(s).

        Args:
            block_name: Forwarded to ``LLMBlock.__init__``.
            config_paths: Sequence of ``(config_path, config_key)`` pairs.
                The first pair's path is forwarded to ``LLMBlock.__init__``.
            client: Forwarded to ``LLMBlock.__init__``.
            model_id: Forwarded to ``LLMBlock.__init__``.
            output_cols: Forwarded to ``LLMBlock.__init__``.
            selector_column_name: Sample column whose value selects the
                template when several configs are given.
            parser_name: ``"default"`` or ``"multi-line-logical-section"``;
                consumed by ``_parse``.
            model_prompt: Forwarded to ``LLMBlock.__init__``.
            **batch_kwargs: Forwarded to ``LLMBlock.__init__``.
        """
        super().__init__(
            block_name,
            config_paths[0][0],
            client,
            model_id,
            output_cols,
            model_prompt=model_prompt,
            **batch_kwargs,
        )
        self.selector_column_name = selector_column_name
        self.prompt_template = {}
        self.parser_name = parser_name
        if len(config_paths) == 1 and config_paths[0][1] == "All":
            # Single shared config: one template string for every sample.
            self.prompt_template = self.prompt_struct.format(**self.block_config)
        else:
            # One template per config key; picked per sample at generate time.
            for config, config_key in config_paths:
                self.prompt_template[config_key] = self.prompt_struct.format(
                    **self._load_config(config)
                )

    def _parse(self, generated_string):
        """Parse one model output according to ``self.parser_name``.

        Returns:
            dict: the parent parser's result for ``"default"``; for
            ``"multi-line-logical-section"`` a dict mapping the first output
            column to the list of extracted sections.

        Raises:
            ValueError: if ``self.parser_name`` is not a supported parser.
                Previously this fell through and returned ``None``.
        """
        if self.parser_name == "default":
            return super()._parse(generated_string)
        if self.parser_name == "multi-line-logical-section":
            return {
                self.output_cols[0]: self.extract_multiline_logical_section(
                    generated_string
                )
            }
        raise ValueError(f"Unsupported parser_name: {self.parser_name!r}")

    def extract_multiline_logical_section(self, text):
        """
        Extracts the body of every "## Logical Section <n>: " chunk in the text.

        Each section runs from just after its header up to the next section
        header or the end of the text; the header itself (including the
        section number) is not part of the result. Sections are returned
        as matched, without stripping whitespace.

        Args:
            text (str): The input text containing logical sections.

        Returns:
            list: The section bodies in order of appearance; empty when the
            text has no section header.
        """
        return self._LOGICAL_SECTION_RE.findall(text)

    def _select_template(self, sample):
        """Return the prompt template string that applies to ``sample``."""
        if isinstance(self.prompt_template, dict):
            return self.prompt_template[sample[self.selector_column_name]]
        return self.prompt_template

    def _generate(self, samples, **gen_kwargs) -> list:
        """Format one prompt per sample, send them in one call, return the texts.

        Args:
            samples: Iterable of mapping-like samples; each is used to fill
                its prompt template via ``str.format(**sample)``.
            **gen_kwargs: Generation parameters; they override
                ``self.defaults`` on key collisions.

        Returns:
            list: The stripped ``text`` of every choice in the response.
        """
        prompts = [
            self.model_prompt.format(
                prompt=self._select_template(sample).format(**sample).strip()
            )
            for sample in samples
        ]
        response = self.client.completions.create(
            prompt=prompts, **{**self.defaults, **gen_kwargs}
        )
        return [choice.text.strip() for choice in response.choices]

    # NOTE: the CI lint job reports W0221 (arguments-differ, relative to
    # ``Block._validate``) and W0613 (unused-argument ``extra_arg``) for this
    # signature. ``extra_arg`` is unused and is kept only so existing callers
    # that pass it keep working.
    def _validate(self, prompt_template: str, input_dict: Dict[str, Any], extra_arg=None) -> bool:
        """Validate ``input_dict`` against the template that applies to it.

        ``prompt_template`` may also be a dict of templates; in that case the
        entry keyed by ``input_dict[self.selector_column_name]`` is resolved
        first and then handed to the parent implementation.
        """
        if isinstance(prompt_template, dict):
            prompt_template = prompt_template[input_dict[self.selector_column_name]]
        return super()._validate(prompt_template, input_dict)
5 changes: 3 additions & 2 deletions src/instructlab/sdg/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def generate(self, dataset) -> Dataset:
for block_prop in self.chained_blocks:
block_type = block_prop["block_type"]
block_config = block_prop["block_config"]
drop_columns = block_prop.get("drop_columns", None)
drop_columns = block_prop.get("drop_columns", [])
gen_kwargs = block_prop.get("gen_kwargs", {})
drop_duplicates_cols = block_prop.get("drop_duplicates", False)
block = block_type(**block_config)
Expand All @@ -50,8 +50,9 @@ def generate(self, dataset) -> Dataset:

dataset = block.generate(dataset, **gen_kwargs)

drop_columns_in_ds = [e for e in drop_columns if e in dataset.column_names]
if drop_columns:
dataset = dataset.remove_columns(drop_columns)
dataset = dataset.remove_columns(drop_columns_in_ds)

if drop_duplicates_cols:
dataset = self._drop_duplicates(dataset, cols=drop_duplicates_cols)
Expand Down
10 changes: 7 additions & 3 deletions src/instructlab/sdg/utilblocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@


class SamplePopulatorBlock(Block):
def __init__(self, config_paths, column_name, **batch_kwargs) -> None:
super().__init__(block_name=self.__class__.__name__)
def __init__(self, config_paths, column_name, post_fix="", **batch_kwargs) -> None:
super().__init__(block_name=self.__class__.__name__) # Call the base class's __init__
self.configs = {}
for config in config_paths:
if post_fix:
config_name = config.replace(".yaml", f"_{post_fix}.yaml")
else:
config_name = config
config_key = config.split("/")[-1].split(".")[0]
self.configs[config_key] = self._load_config(config)
self.configs[config_key] = self._load_config(config_name)
self.column_name = column_name
self.num_procs = batch_kwargs.get("num_procs", 8)

Expand Down
Loading