Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

feat: add inference and evaluation script with dataset transformations #733

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
9 changes: 9 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ WORKDIR /build/fairscale
RUN git checkout fixing_memory_issues_with_keeping_overlap_may24
RUN pip3 install -e .

RUN pip install \
py-rouge==1.1 \
rouge_score==0.1.2 \
parlai==1.7.1 \
evaluate==0.4.0

ENV NLTK_DATA="/usr/share/nltk_data"
RUN python -c "import nltk; nltk.download('punkt', download_dir='${NLTK_DATA}')"

# Install metaseq
WORKDIR /build
RUN git clone https://github.com/facebookresearch/metaseq.git
Expand Down
765 changes: 765 additions & 0 deletions metaseq/cli/inference.py

Large diffs are not rendered by default.

18 changes: 3 additions & 15 deletions metaseq/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from metaseq.file_io import PathManager
from metaseq.logging import meters, metrics, progress_bar
from metaseq.trainer import Trainer
from metaseq.utils import flatten_config

logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
Expand Down Expand Up @@ -75,7 +76,7 @@ def main(cfg: DictConfig) -> None:
# TODO(roller): only works when launched with a sweep script
# should fix that
OmegaConf.save(
config=_flatten_config(cfg),
config=flatten_config(cfg),
f=os.path.join(os.environ["METASEQ_SAVE_DIR"], "config.yml"),
)

Expand Down Expand Up @@ -288,7 +289,7 @@ def train(
"WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
),
)
progress.update_config(_flatten_config(cfg))
progress.update_config(flatten_config(cfg))

trainer.begin_epoch(epoch_itr.epoch)
valid_subsets = cfg.dataset.valid_subset.split(",")
Expand Down Expand Up @@ -411,19 +412,6 @@ def train(
return valid_losses, should_stop


def _flatten_config(cfg: DictConfig):
    """Convert a DictConfig into a plain container, folding any legacy
    argparse.Namespace values into a single "args" dict entry.

    NOTE(review): if more than one Namespace value is present, only the
    last one encountered survives under "args" — confirm that is intended.
    """
    config = OmegaConf.to_container(cfg)
    # remove any legacy Namespaces and replace with a single "args"
    namespace = None
    for k, v in list(config.items()):
        if isinstance(v, argparse.Namespace):
            namespace = v
            del config[k]
    if namespace is not None:
        config["args"] = vars(namespace)
    return config


def validate_and_save(
cfg: DictConfig,
trainer: Trainer,
Expand Down
Empty file.
21 changes: 21 additions & 0 deletions metaseq/data/datasets/cnn_dm_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Any, Dict

from metaseq.data.datasets import openai_generated_transformers, shared_transformers
from metaseq.data.datasets.types import OAITeacherGeneratedDatasetItem


def before_transforming_into_metaseq_inference(raw_dict: Any) -> OAITeacherGeneratedDatasetItem:
    """Sanitize a raw teacher-generated CNN/DailyMail record before it is
    turned into the metaseq inference format.

    Drops every token after the OpenAI EOS token and renames the EOS token
    to ``</s>``.
    """
    sanitized: OAITeacherGeneratedDatasetItem = raw_dict

    sanitized = openai_generated_transformers.remove_all_tokens_after_eos_sanitizer(sanitized)
    sanitized = openai_generated_transformers.replace_eos_sanitizer(sanitized, eos_replacement="</s>")

    return sanitized


def convert_teacher_domain_to_original_domain(model_output: str) -> str:
    """Convert output from text-davinci-003 to the format of original dataset"""
    # The teacher model tends to prefix its answer with whitespace/punctuation;
    # stripping leading non-alphabetic characters recovers the original-domain text.
    return shared_transformers.remove_non_alpha_from_beginning(model_output)
84 changes: 84 additions & 0 deletions metaseq/data/datasets/dataset_configurations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from metaseq.data.datasets import e2e_transformers, hellaswag_transformers, piqa_transformers, reddit_transformers, cnn_dm_transformers
from metaseq.data.datasets.types import CommonDatasetConfiguration, DatasetConfiguration, DatasetConfigurationTeacherGenerated, DatasetModelConfig, DatasetModelHooks, DatasetTeacherGeneratedDataHooks, IdentityDict

# Visual diagram of where hooks/functions are called during inference or data generation
# https://excalidraw.com/#json=zoAk_TdynBHQnP9vZufGm,ekcVg_HqiF79cAp58_HKRQ
# NOTE (from PR review discussion): this visualization may be important for
# understanding where the hooks below are called.
# Maps dataset name -> configuration bundling the metrics to compute and the
# transformation hooks to run when processing that dataset.
DATASET_CONFIGURATIONS = {
    "cnn_dailymail":
        DatasetConfiguration(
            # Summarization-style scoring: BERTScore and ROUGE-L.
            common=CommonDatasetConfiguration(metric_libraries=["grindstone", "coco"], metric_names=["bertscore", "rouge-L"]),
            model_config=DatasetModelConfig(
                model_hooks=DatasetModelHooks(
                    # Only "distilled" model output gets an explicit converter;
                    # other model types presumably pass through via IdentityDict
                    # semantics — TODO confirm.
                    convert_model_type_output_to_original_domain=IdentityDict(
                        {
                            "distilled": cnn_dm_transformers.convert_teacher_domain_to_original_domain,
                        }
                    )
                )
            ),
            teacher_generated_config=DatasetConfigurationTeacherGenerated(
                data_hooks=DatasetTeacherGeneratedDataHooks(
                    # Sanitizes EOS handling before metaseq-inference conversion.
                    before_transforming_into_metaseq_inference=cnn_dm_transformers.before_transforming_into_metaseq_inference,
                    convert_test_target_to_original_domain_label=cnn_dm_transformers.convert_teacher_domain_to_original_domain
                ),
            ),
        ),
    "e2e_nlg":
        DatasetConfiguration(
            common=CommonDatasetConfiguration(metric_libraries=["grindstone", "coco"], metric_names=["bertscore", "rouge-L"]),
            # No model_config here: no per-model-type output conversion is needed.
            teacher_generated_config=DatasetConfigurationTeacherGenerated(
                data_hooks=DatasetTeacherGeneratedDataHooks(
                    before_transforming_into_metaseq_inference=e2e_transformers.before_transforming_into_metaseq_inference,
                ),
            ),
        ),
    "hellaswag":
        DatasetConfiguration(
            # Multiple-choice task: accuracy only.
            common=CommonDatasetConfiguration(metric_libraries=["grindstone"], metric_names=["accuracy"]),
            model_config=DatasetModelConfig(
                model_hooks=DatasetModelHooks(
                    convert_model_type_output_to_original_domain=IdentityDict(
                        {
                            "distilled": hellaswag_transformers.hellaswag_convert_model_output_domain_to_original_domain,
                        }
                    )
                )
            ),
            teacher_generated_config=DatasetConfigurationTeacherGenerated(
                data_hooks=DatasetTeacherGeneratedDataHooks(
                    before_transforming_into_metaseq_inference=hellaswag_transformers.hellaswag_before_transforming_into_metaseq_inference,
                    # Same converter is reused to normalize the test targets.
                    convert_test_target_to_original_domain_label=hellaswag_transformers.hellaswag_convert_model_output_domain_to_original_domain,
                ),
            ),
        ),
    "piqa":
        DatasetConfiguration(
            common=CommonDatasetConfiguration(metric_libraries=["grindstone"], metric_names=["accuracy"]),
            model_config=DatasetModelConfig(
                model_hooks=DatasetModelHooks(
                    convert_model_type_output_to_original_domain=IdentityDict(
                        {
                            "distilled": piqa_transformers.model_output_to_orig,
                        }
                    )
                )
            ),
            teacher_generated_config=DatasetConfigurationTeacherGenerated(
                data_hooks=DatasetTeacherGeneratedDataHooks(
                    before_transforming_into_metaseq_inference=piqa_transformers.preprocess_teacher_generated_data,
                    convert_test_target_to_original_domain_label=piqa_transformers.model_output_to_orig
                )
            )
        ),
    "openai_tldr_reddit":
        DatasetConfiguration(
            common=CommonDatasetConfiguration(metric_libraries=["grindstone", "coco"], metric_names=["bertscore", "rouge-L"]),
            teacher_generated_config=DatasetConfigurationTeacherGenerated(
                data_hooks=DatasetTeacherGeneratedDataHooks(
                    before_transforming_into_metaseq_inference=reddit_transformers.before_transforming_into_metaseq_inference,
                ),
            ),
        ),
}
14 changes: 14 additions & 0 deletions metaseq/data/datasets/e2e_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Any, Dict

from metaseq.data.datasets import openai_generated_transformers
from metaseq.data.datasets.types import OAITeacherGeneratedDatasetItem


def before_transforming_into_metaseq_inference(raw_dict: Any) -> OAITeacherGeneratedDatasetItem:
    """Sanitize a raw teacher-generated E2E-NLG record before it is turned
    into the metaseq inference format.

    Strips non-letter characters from the beginning, drops every token after
    the OpenAI EOS token, then renames the EOS token to ``</s>``.
    """
    item: OAITeacherGeneratedDatasetItem = raw_dict

    # Apply the shared sanitizers in order; each one mutates/returns the item.
    for transform in (
        openai_generated_transformers.sanitize_beginning,
        openai_generated_transformers.remove_all_tokens_after_eos_sanitizer,
    ):
        item = transform(item)

    return openai_generated_transformers.replace_eos_sanitizer(item, eos_replacement="</s>")
65 changes: 65 additions & 0 deletions metaseq/data/datasets/hellaswag_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from typing import Any, Dict

from metaseq.data.datasets import openai_generated_transformers
from metaseq.data.datasets.shared_transformers import get_first_number
from metaseq.data.datasets.types import OAITeacherGeneratedDatasetItem, OAITeacherGeneratedDatasetItemLogprobs


def _adjust_teacher_generated_format(data: Dict) -> OAITeacherGeneratedDatasetItem:
    """
    Reshape a raw teacher-generated hellaswag record into the shape we expect
    for OpenAI output (the stored format does not match it directly).
    """
    response = data["response"]

    # Shift the human label from range [0,3] to [1,4] so it matches the
    # labels produced by the download_hellaswag script. A missing label
    # defaults to "-2" and therefore maps to -1.
    shifted_label = int(data.get("label", "-2")) + 1

    return {
        "source": data["prompt"],
        "human": str(shifted_label),
        "text": response["text"],
        "finish_reason": response["finish_reason"],
        "index": data["ind"],
        "logprobs": response["logprobs"],
    }


def hellaswag_before_transforming_into_metaseq_inference(raw_dict: Any) -> OAITeacherGeneratedDatasetItem:
    """Normalize a teacher-generated hellaswag record into the OpenAI output
    shape and sanitize it for metaseq inference.

    :raises ValueError: when the generated text contains no number (callers
        skip such items).
    """
    # First reshape the record into the expected OpenAI response layout.
    item = _adjust_teacher_generated_format(raw_dict)

    # Then run the shared sanitizers in order:
    # 1. drop everything after the EOS token,
    # 2. rename EOS to "</s>",
    # 3. truncate after the first token containing a closing bracket
    #    (the regex matches any token with ")" in it, e.g. ") " or ")").
    for sanitize in (
        openai_generated_transformers.remove_all_tokens_after_eos_sanitizer,
        lambda d: openai_generated_transformers.replace_eos_sanitizer(d, eos_replacement="</s>"),
        lambda d: openai_generated_transformers.truncate_after_token(d, r".*?\).*?"),
    ):
        item = sanitize(item)

    # Verify the target text contains a number (the predicted choice);
    # get_first_number asserts when none is found, which we surface as a
    # ValueError so the item can be skipped.
    try:
        get_first_number(item["text"])
    except AssertionError:
        raise ValueError(f"Could not find a number in the generated text: {item['text']}")

    return item


def hellaswag_convert_model_output_domain_to_original_domain(model_output: str) -> str:
    """Extract the choice index from raw model output.

    Example input: ``' (4) something something'`` -> ``'4'``.
    The model-generated label is in range [1,4].
    """
    return get_first_number(model_output).strip()
149 changes: 149 additions & 0 deletions metaseq/data/datasets/openai_generated_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from typing import List

import regex

from metaseq.data.datasets.shared_transformers import remove_non_alpha_from_beginning
from metaseq.data.datasets.types import OAITeacherGeneratedDatasetItem


def sanitize_beginning(data: OAITeacherGeneratedDatasetItem) -> OAITeacherGeneratedDatasetItem:
    """
    Remove all non-letter characters from the beginning of the text and the
    corresponding leading entries from the parallel logprobs lists.

    Also note that this removes any non-alpha character/tokens from the
    beginning of the text, which is not always desired.

    :return: the same (mutated) item, for chaining.
    """

    data["text"] = remove_non_alpha_from_beginning(data["text"])

    # also remove non-letter tokens from the beginning of the tokens list
    logprobs_dict = data["logprobs"]
    token_list: List[str] = logprobs_dict["tokens"]

    # Find the first token containing letters. Bounds check added: the
    # original loop raised IndexError when the token list was empty or held
    # no alphabetic token; now every token is dropped instead, which matches
    # the fully-stripped text.
    first_valid_idx = 0
    while first_valid_idx < len(token_list) and not token_list[first_valid_idx].strip().isalpha():
        first_valid_idx += 1

    # Keep all parallel per-token lists aligned with the tokens list.
    for key in [
        "tokens",
        "token_logprobs",
        "top_logprobs",
        "text_offset",
    ]:
        logprobs_dict[key] = logprobs_dict[key][first_valid_idx:]

    return data


def remove_all_tokens_after_eos_sanitizer(
    data: OAITeacherGeneratedDatasetItem, eos_token_name="<|endoftext|>"
) -> OAITeacherGeneratedDatasetItem:
    """
    Remove the first EOS token and every token after it from the parallel
    logprobs lists.

    :param str eos_token_name: The name of the EOS token, defaults to
        "<|endoftext|>"
    :raises AssertionError: if the token list contains no EOS token.
    """

    # it can be that there are some samples whose last token is not
    # "<|endoftext|>". According to conversation with Subho here [1] we should
    # remove all tokens after the endoftext
    #
    # [1]:
    # https://teams.microsoft.com/l/message/19:[email protected]/1677880904286?tenantId=72f988bf-86f1-41af-91ab-2d7cd011db47&groupId=72b4c54c-a4e8-4f3e-b2c3-2bbeaf09e0ff&parentMessageId=1677718389179&teamName=Distillery&channelName=General&createdTime=1677880904286&allowXTenantAccess=false
    logprobs_dict = data["logprobs"]
    token_list: List[str] = logprobs_dict["tokens"]

    # Single scan instead of the previous `assert ... in` + `.index` double
    # scan. The AssertionError contract for a missing EOS token is preserved
    # (and, unlike a bare assert, still holds under `python -O`).
    try:
        eos_index = token_list.index(eos_token_name)
    except ValueError:
        raise AssertionError(f"EOS token {eos_token_name!r} not found in token list")

    # remove everything after this index (even the eos token itself)
    for key in [
        "tokens",
        "token_logprobs",
        "top_logprobs",
        "text_offset",
    ]:
        logprobs_dict[key] = logprobs_dict[key][:eos_index]

    return data


def replace_eos_sanitizer(
    data: OAITeacherGeneratedDatasetItem,
    eos_replacement: str = "</s>",
    eos_token_name="<|endoftext|>"
) -> OAITeacherGeneratedDatasetItem:
    """
    Rename every occurrence of the EOS token in the logprobs structure.

    :param str eos_replacement: New name for the EOS token we want to use,
        defaults to "</s>"
    :param str eos_token_name: Old name that was being used for the EOS token,
        defaults to "<|endoftext|>"
    """

    logprobs = data["logprobs"]

    # Rewrite the flat token sequence in place.
    token_seq = logprobs["tokens"]
    for position, token in enumerate(token_seq):
        if token == eos_token_name:
            token_seq[position] = eos_replacement

    # Each per-position candidate dict may also carry the EOS token as a key;
    # move its logprob over to the replacement key.
    for candidates in logprobs["top_logprobs"]:
        if eos_token_name in candidates:
            candidates[eos_replacement] = candidates.pop(eos_token_name)

    return data


def truncate_after_token(data: OAITeacherGeneratedDatasetItem, rgx: str) -> OAITeacherGeneratedDatasetItem:
    """
    Truncate the text and the parallel logprobs lists AFTER the first token
    that matches the given regex.

    :param str rgx: The regex to match the token after which we should truncate
    :raises ValueError: when no token matches the regex
    """
    matcher = regex.compile(rgx, flags=regex.MULTILINE | regex.DOTALL)

    logprobs = data["logprobs"]
    tokens: List[str] = logprobs["tokens"]

    # Locate the first matching token; the for/else raises when the loop
    # finishes without a break, i.e. no token matched.
    for match_idx, candidate in enumerate(tokens):
        if matcher.match(candidate):
            break
    else:
        raise ValueError(f"Could not find any token that matches the regex {rgx}.")

    # Keep the matching token itself and drop everything after it.
    keep = match_idx + 1
    for key in [
        "tokens",
        "token_logprobs",
        "top_logprobs",
        "text_offset",
    ]:
        logprobs[key] = logprobs[key][:keep]

    # The text must mirror exactly the tokens that were kept.
    data["text"] = "".join(tokens[:keep])

    return data
Loading