Skip to content

Commit

Permalink
added
Browse files Browse the repository at this point in the history
  • Loading branch information
tim_sockel committed Dec 18, 2023
1 parent 57bc36c commit 3436c1a
Showing 1 changed file with 93 additions and 0 deletions.
93 changes: 93 additions & 0 deletions textattack/augmentation/augmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
===================
"""
import random
from collections import Counter

import tqdm

Expand Down Expand Up @@ -230,6 +231,98 @@ def augment_text_with_ids(self, text_list, id_list, show_progress=True):
all_id_list.extend([_id] * (1 + len(augmented_texts)))
return all_text_list, all_id_list

def augment_text_with_ids_evenly(
self,
text_list,
id_list,
additional_examples=0,
perfectly_even=True,
show_progress=True,
):
"""Supplements a list of text with more text data so that there is approximately
the same number of sentences for each label.
Each ID from `id_list` will be represented the same number of times
as the most frequent ID plus `additional_examples`.
If `perfectly_even` is set to `True`, every ID will be occurring exactly the same number of times (recommended,
but slightly slower).
Returns the augmented text along with the corresponding IDs for
each augmented example.
"""
if len(text_list) != len(id_list):
raise ValueError("List of text must be same length as list of IDs")
if additional_examples < 0:
raise ValueError("Additional examples must be non-negative")
all_text_list = []
all_id_list = []
examples_per_id = Counter(id_list)
max_examples = max(examples_per_id.values()) + additional_examples
diff_per_example = {k: max_examples - v for k, v in examples_per_id.items()}
original_transformations_per_example = self.transformations_per_example
remainders = {}
if show_progress:
text_list = tqdm.tqdm(text_list, desc="Augmenting data...")
for text, _id in zip(text_list, id_list):
# distribute augmentation of the original documents evenly
self.transformations_per_example = (
diff_per_example[_id] // examples_per_id[_id]
)
remainders[_id] = diff_per_example[_id] % examples_per_id[_id]
all_text_list.append(text)
all_id_list.append(_id)
if self.transformations_per_example > 0:
augmented_texts = []
while len(augmented_texts) < self.transformations_per_example:
augmented_texts.extend(self.augment(text))
all_text_list.extend(augmented_texts)
all_id_list.extend([_id] * len(augmented_texts))

if perfectly_even:
self.transformations_per_example = 1
# (1) add missing examples:
if show_progress:
added = tqdm.tqdm(
desc="Adding additional examples...", total=sum(remainders.values())
)
while any(remainders.values()):
for text, _id in zip(text_list, id_list):
if remainders[_id] > 0:
# add missing elements one-by-one
remainders[_id] -= 1
if show_progress:
added.update(1)
augmented_texts = self.augment(text)
all_text_list.extend(augmented_texts)
all_id_list.append(_id)
if show_progress:
added.close()
# (2) remove excess:
excess = {k: v - max_examples for k, v in Counter(all_id_list).items()}
new_id_list = []
new_text_list = []
if show_progress:
to_be_removed = int(sum([e > 0 for e in excess.values()]))
removed = tqdm.tqdm(
desc="Removing abundant examples...", total=to_be_removed
)
# count backwards so that the newer elements (most probably being augmented) are deleted first
for i in range(len(all_id_list) - 1, -1, -1):
if excess[all_id_list[i]] <= 0:
new_id_list.append(all_id_list[i])
new_text_list.append(all_text_list[i])
else:
# skip entry for new id and text list
excess[all_id_list[i]] -= 1
if show_progress:
removed.update(1)
all_id_list = new_id_list
all_text_list = new_text_list
if show_progress:
removed.close()

self.transformations_per_example = original_transformations_per_example
return all_text_list, all_id_list

def __repr__(self):
main_str = "Augmenter" + "("
lines = []
Expand Down

0 comments on commit 3436c1a

Please sign in to comment.