This repository has been archived by the owner on Nov 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 726
feat: add inference and evaluation script with dataset transformations #733
Open
mattmazzola
wants to merge
12
commits into
facebookresearch:main
Choose a base branch
from
mattmazzola:mattm/inference-scripts
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
3317dea
Add inference CLI script
mattmazzola 0a82a95
Add more missing components
mattmazzola d99055e
Add data transformers
mattmazzola 68cdbb1
More missed files
mattmazzola 7e1c21d
buld build typo
mattmazzola 64602f8
remove generation metrics init
mattmazzola 0792397
rename to strip_token
mattmazzola 655c89f
reset to 0 when generating tokens
mattmazzola 788902f
Update comment
mattmazzola bb99337
Add comment to clarify dataset loading
mattmazzola 39d6e05
remove comment
mattmazzola 14c6c93
remove aim
mattmazzola File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from typing import Any, Dict | ||
|
||
from metaseq.data.datasets import openai_generated_transformers, shared_transformers | ||
from metaseq.data.datasets.types import OAITeacherGeneratedDatasetItem | ||
|
||
|
||
def before_transforming_into_metaseq_inference(raw_dict: Any) -> OAITeacherGeneratedDatasetItem:
    """Prepare one raw teacher-generated record for metaseq inference.

    Drops the EOS token and everything after it from the logprobs lists,
    then renames the OpenAI EOS token to metaseq's "</s>".
    """
    record: OAITeacherGeneratedDatasetItem = raw_dict

    record = openai_generated_transformers.remove_all_tokens_after_eos_sanitizer(record)
    return openai_generated_transformers.replace_eos_sanitizer(record, eos_replacement="</s>")
|
||
|
||
def convert_teacher_domain_to_original_domain(model_output: str) -> str:
    """Convert output from text-davinci-003 to the format of original dataset"""
    # The teacher tends to lead with punctuation/whitespace; strip any
    # non-letter characters from the front before comparing to references.
    return shared_transformers.remove_non_alpha_from_beginning(model_output)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from metaseq.data.datasets import e2e_transformers, hellaswag_transformers, piqa_transformers, reddit_transformers, cnn_dm_transformers | ||
from metaseq.data.datasets.types import CommonDatasetConfiguration, DatasetConfiguration, DatasetConfigurationTeacherGenerated, DatasetModelConfig, DatasetModelHooks, DatasetTeacherGeneratedDataHooks, IdentityDict | ||
|
||
# Visual diagram of where hooks/functions are called during inference or data generation
# https://excalidraw.com/#json=zoAk_TdynBHQnP9vZufGm,ekcVg_HqiF79cAp58_HKRQ
#
# Registry mapping dataset name -> DatasetConfiguration. Each entry declares:
#   * common: the metric libraries and metric names used for evaluation,
#   * model_config (optional): per-model-type hooks that convert model output
#     back into the original dataset's label domain ("distilled" is the only
#     model type configured here),
#   * teacher_generated_config (optional): hooks applied to teacher-generated
#     (OpenAI) data before it is transformed into metaseq inference format.
DATASET_CONFIGURATIONS = {
    # Summarization dataset: scored with bertscore + rouge-L.
    "cnn_dailymail":
        DatasetConfiguration(
            common=CommonDatasetConfiguration(metric_libraries=["grindstone", "coco"], metric_names=["bertscore", "rouge-L"]),
            model_config=DatasetModelConfig(
                model_hooks=DatasetModelHooks(
                    convert_model_type_output_to_original_domain=IdentityDict(
                        {
                            "distilled": cnn_dm_transformers.convert_teacher_domain_to_original_domain,
                        }
                    )
                )
            ),
            teacher_generated_config=DatasetConfigurationTeacherGenerated(
                data_hooks=DatasetTeacherGeneratedDataHooks(
                    before_transforming_into_metaseq_inference=cnn_dm_transformers.before_transforming_into_metaseq_inference,
                    convert_test_target_to_original_domain_label=cnn_dm_transformers.convert_teacher_domain_to_original_domain
                ),
            ),
        ),
    # Data-to-text generation: scored with bertscore + rouge-L; no
    # model-output conversion hooks needed.
    "e2e_nlg":
        DatasetConfiguration(
            common=CommonDatasetConfiguration(metric_libraries=["grindstone", "coco"], metric_names=["bertscore", "rouge-L"]),
            teacher_generated_config=DatasetConfigurationTeacherGenerated(
                data_hooks=DatasetTeacherGeneratedDataHooks(
                    before_transforming_into_metaseq_inference=e2e_transformers.before_transforming_into_metaseq_inference,
                ),
            ),
        ),
    # Multiple-choice dataset: scored by accuracy, so both the model output
    # and the test target are converted to the original label domain.
    "hellaswag":
        DatasetConfiguration(
            common=CommonDatasetConfiguration(metric_libraries=["grindstone"], metric_names=["accuracy"]),
            model_config=DatasetModelConfig(
                model_hooks=DatasetModelHooks(
                    convert_model_type_output_to_original_domain=IdentityDict(
                        {
                            "distilled": hellaswag_transformers.hellaswag_convert_model_output_domain_to_original_domain,
                        }
                    )
                )
            ),
            teacher_generated_config=DatasetConfigurationTeacherGenerated(
                data_hooks=DatasetTeacherGeneratedDataHooks(
                    before_transforming_into_metaseq_inference=hellaswag_transformers.hellaswag_before_transforming_into_metaseq_inference,
                    convert_test_target_to_original_domain_label=hellaswag_transformers.hellaswag_convert_model_output_domain_to_original_domain,
                ),
            ),
        ),
    # Multiple-choice dataset: same accuracy-based setup as hellaswag.
    "piqa":
        DatasetConfiguration(
            common=CommonDatasetConfiguration(metric_libraries=["grindstone"], metric_names=["accuracy"]),
            model_config=DatasetModelConfig(
                model_hooks=DatasetModelHooks(
                    convert_model_type_output_to_original_domain=IdentityDict(
                        {
                            "distilled": piqa_transformers.model_output_to_orig,
                        }
                    )
                )
            ),
            teacher_generated_config=DatasetConfigurationTeacherGenerated(
                data_hooks=DatasetTeacherGeneratedDataHooks(
                    before_transforming_into_metaseq_inference=piqa_transformers.preprocess_teacher_generated_data,
                    convert_test_target_to_original_domain_label=piqa_transformers.model_output_to_orig
                )
            )
        ),
    # Summarization dataset: scored with bertscore + rouge-L.
    "openai_tldr_reddit":
        DatasetConfiguration(
            common=CommonDatasetConfiguration(metric_libraries=["grindstone", "coco"], metric_names=["bertscore", "rouge-L"]),
            teacher_generated_config=DatasetConfigurationTeacherGenerated(
                data_hooks=DatasetTeacherGeneratedDataHooks(
                    before_transforming_into_metaseq_inference=reddit_transformers.before_transforming_into_metaseq_inference,
                ),
            ),
        ),
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from typing import Any, Dict | ||
|
||
from metaseq.data.datasets import openai_generated_transformers | ||
from metaseq.data.datasets.types import OAITeacherGeneratedDatasetItem | ||
|
||
|
||
def before_transforming_into_metaseq_inference(raw_dict: Any) -> OAITeacherGeneratedDatasetItem:
    """Prepare one raw teacher-generated record for metaseq inference.

    Strips non-letter noise from the beginning of the record, drops the EOS
    token and everything after it, then renames the OpenAI EOS token to
    metaseq's "</s>".
    """
    record: OAITeacherGeneratedDatasetItem = raw_dict

    record = openai_generated_transformers.sanitize_beginning(record)
    record = openai_generated_transformers.remove_all_tokens_after_eos_sanitizer(record)
    return openai_generated_transformers.replace_eos_sanitizer(record, eos_replacement="</s>")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from typing import Any, Dict | ||
|
||
from metaseq.data.datasets import openai_generated_transformers | ||
from metaseq.data.datasets.shared_transformers import get_first_number | ||
from metaseq.data.datasets.types import OAITeacherGeneratedDatasetItem, OAITeacherGeneratedDatasetItemLogprobs | ||
|
||
|
||
def _adjust_teacher_generated_format(data: Dict) -> OAITeacherGeneratedDatasetItem:
    """
    The teacher-generated hellaswag records are not shaped like OpenAI
    completion output, so reshape one record into the expected structure.
    """
    response = data["response"]

    # Shift the human label from the range [0,3] to [1,4] so it matches the
    # labels produced by the download_hellaswag script. A record with no
    # "label" key ends up with the sentinel value "-1".
    shifted_label = int(data.get("label", "-2")) + 1

    return {
        "source": data["prompt"],
        "human": str(shifted_label),
        "text": response["text"],
        "finish_reason": response["finish_reason"],
        "index": data["ind"],
        "logprobs": response["logprobs"],
    }
|
||
|
||
def hellaswag_before_transforming_into_metaseq_inference(raw_dict: Any) -> OAITeacherGeneratedDatasetItem:
    """Reshape and sanitize one teacher-generated hellaswag record."""
    # First bring the record into the OpenAI completion output shape.
    item = _adjust_teacher_generated_format(raw_dict)

    # Drop the EOS token and everything after it, then rename the OpenAI
    # EOS token to metaseq's "</s>".
    item = openai_generated_transformers.remove_all_tokens_after_eos_sanitizer(item)
    item = openai_generated_transformers.replace_eos_sanitizer(item, eos_replacement="</s>")

    # If present, truncate after the first token containing a closing
    # bracket — the regex matches any token with a ")" in it, e.g. ") " or
    # ")" — since the answer index is wrapped in parentheses.
    item = openai_generated_transformers.truncate_after_token(item, r".*?\).*?")

    # The target text must contain a number (the answer choice).
    # get_first_number asserts internally; surface a clearer error so the
    # caller can skip this item.
    try:
        get_first_number(item["text"])
    except AssertionError:
        raise ValueError(f"Could not find a number in the generated text: {item['text']}")

    return item
|
||
|
||
def hellaswag_convert_model_output_domain_to_original_domain(model_output: str) -> str:
    """Extract the generated answer index from raw model output.

    Example model_output: ' (4) something something'.
    The model-generated label is in the range [1,4].
    """
    return get_first_number(model_output).strip()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
from typing import List | ||
|
||
import regex | ||
|
||
from metaseq.data.datasets.shared_transformers import remove_non_alpha_from_beginning | ||
from metaseq.data.datasets.types import OAITeacherGeneratedDatasetItem | ||
|
||
|
||
def sanitize_beginning(data: "OAITeacherGeneratedDatasetItem") -> "OAITeacherGeneratedDatasetItem":
    """
    Remove all non-letter characters from the beginning of the text and the
    corresponding entries from the aligned logprobs lists.

    Also note that this removes any non-alpha character/tokens from the
    beginning of the text, which is not always desired.

    :param data: a single teacher-generated item; mutated in place.
    :return: the same item with text and logprobs lists trimmed.
    """

    data["text"] = remove_non_alpha_from_beginning(data["text"])

    # also remove non-letter tokens from the beginning of the tokens list
    logprobs_dict = data["logprobs"]
    token_list: List[str] = logprobs_dict["tokens"]

    first_valid_idx = 0
    # Bounds check: without it, a record whose tokens are all non-alpha
    # would raise IndexError; now such a record simply ends up with empty
    # logprobs lists.
    while first_valid_idx < len(token_list) and not token_list[first_valid_idx].strip().isalpha():
        first_valid_idx += 1

    # Trim every aligned list by the same amount so they stay in sync.
    for key in [
        "tokens",
        "token_logprobs",
        "top_logprobs",
        "text_offset",
    ]:
        logprobs_dict[key] = logprobs_dict[key][first_valid_idx:]

    return data
|
||
|
||
def remove_all_tokens_after_eos_sanitizer(
    data: "OAITeacherGeneratedDatasetItem", eos_token_name="<|endoftext|>"
) -> "OAITeacherGeneratedDatasetItem":
    """
    Remove the first EOS token and every token after it from the aligned
    logprobs lists.

    It can be that there are some samples whose last token is not
    "<|endoftext|>". According to conversation with Subho here [1] we should
    remove all tokens after the endoftext.

    [1]:
    https://teams.microsoft.com/l/message/19:[email protected]/1677880904286?tenantId=72f988bf-86f1-41af-91ab-2d7cd011db47&groupId=72b4c54c-a4e8-4f3e-b2c3-2bbeaf09e0ff&parentMessageId=1677718389179&teamName=Distillery&channelName=General&createdTime=1677880904286&allowXTenantAccess=false

    :param str eos_token_name: The name of the EOS token, defaults to
        "<|endoftext|>"
    :raises ValueError: if the EOS token is not present in the token list.
    """
    logprobs_dict = data["logprobs"]
    token_list: List[str] = logprobs_dict["tokens"]

    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, and a missing EOS token is a data problem, not a bug.
    if eos_token_name not in token_list:
        raise ValueError(f"EOS token {eos_token_name!r} not found in token list")

    eos_index = token_list.index(eos_token_name)

    # remove everything after this index (even the eos token)
    for key in [
        "tokens",
        "token_logprobs",
        "top_logprobs",
        "text_offset",
    ]:
        logprobs_dict[key] = logprobs_dict[key][:eos_index]

    return data
|
||
|
||
def replace_eos_sanitizer(
    data: "OAITeacherGeneratedDatasetItem",
    eos_replacement: str = "</s>",
    eos_token_name="<|endoftext|>"
) -> "OAITeacherGeneratedDatasetItem":
    """
    Replace the EOS token name with the given replacement string, both in
    the flat token list and in every per-position top-logprobs mapping.

    :param str eos_replacement: New name for the EOS token we want to use,
        defaults to "</s>"
    :param str eos_token_name: Old name that was being used for the EOS token,
        defaults to "<|endoftext|>"
    """

    logprobs_dict = data["logprobs"]

    # enumerate instead of range(len(...)) — idiomatic index + value loop.
    tokens = logprobs_dict["tokens"]
    for t_idx, token in enumerate(tokens):
        if token == eos_token_name:
            tokens[t_idx] = eos_replacement

    top_logprobs = logprobs_dict["top_logprobs"]
    for logprob_dict in top_logprobs:
        if eos_token_name in logprob_dict:
            # remove existing item and assign it to eos_replacement token,
            # preserving its logprob value
            logprob_dict[eos_replacement] = logprob_dict.pop(eos_token_name)

    return data
|
||
|
||
def truncate_after_token(data: "OAITeacherGeneratedDatasetItem", rgx: str) -> "OAITeacherGeneratedDatasetItem":
    """
    Truncate the text and tokens list AFTER the first token that matches
    the given regex.

    :param str rgx: The regex to match the token after which we should truncate
    :raises ValueError: if no token matches the regex.
    """
    # The stdlib `re` engine supports everything the patterns used in this
    # repo need (e.g. r".*?\).*?"), so the third-party `regex` package is
    # not required here. NOTE(review): if callers ever pass `regex`-only
    # syntax (e.g. \p{L}), switch back to `regex`.
    import re

    token_matcher = re.compile(rgx, flags=re.MULTILINE | re.DOTALL)

    logprobs_dict = data["logprobs"]
    token_list: List[str] = logprobs_dict["tokens"]

    # Find the first token matching the regex, accumulating the text seen
    # before it. The for/else fires only when no token matched at all.
    seen_text = ""
    for index_of_last_token, token in enumerate(token_list):
        if token_matcher.match(token):
            break
        seen_text += token
    else:
        raise ValueError(f"Could not find any token that matches the regex {rgx}.")

    # We stop at the index of the token that matched the regex, meaning we
    # want to drop everything after this index, so we add the matched token
    # to the seen text and then truncate one position past it.
    seen_text += token_list[index_of_last_token]
    index_of_token_after_last = index_of_last_token + 1

    # Truncate every aligned logprobs list by the same amount.
    for key in [
        "tokens",
        "token_logprobs",
        "top_logprobs",
        "text_offset",
    ]:
        logprobs_dict[key] = logprobs_dict[key][:index_of_token_after_last]

    # Keep the text in sync with the (now truncated) token list.
    data["text"] = seen_text

    return data
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This visualization may be important for understanding