
Commit

test
IlyasMoutawwakil committed Jul 30, 2024
1 parent 04f713f commit e833ffc
Showing 1 changed file with 21 additions and 60 deletions.
81 changes: 21 additions & 60 deletions optimum_benchmark/generators/task_generator.py
@@ -2,13 +2,17 @@
 import random
 import string
 from abc import ABC
-from typing import Tuple
+from typing import List, Tuple
 
 # TODO: drop torch dependency and use numpy instead
 import torch
 
 LOGGER = logging.getLogger("generators")
 
+DEFAULT_NUM_LABELS = 2
+DEFAULT_VOCAB_SIZE = 2
+DEFAULT_TYPE_VOCAB_SIZE = 2
+
 
 class TaskGenerator(ABC):
     def __init__(self, shapes, with_labels: bool):
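For context, the new module-level defaults are applied at the call sites below with `or`, replacing the per-argument None checks that this commit removes from the helpers. A minimal sketch of the pattern, with illustrative shapes values:

DEFAULT_VOCAB_SIZE = 2

# Illustrative shapes dict; in the benchmark these values come from the task/model configuration.
shapes = {"vocab_size": None, "batch_size": 2, "sequence_length": 8}

# Falls back to the default when the configured value is None (note that `or`
# would also override a legitimate falsy value such as 0).
max_value = shapes["vocab_size"] or DEFAULT_VOCAB_SIZE
print(max_value)  # 2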
@@ -17,61 +21,21 @@ def __init__(self, shapes, with_labels: bool):
 
     @staticmethod
     def generate_random_integers(min_value: int, max_value: int, shape: Tuple[int]):
-        if min_value is None:
-            LOGGER.warning("min_value is None, setting it to 0")
-            min_value = 0
-
-        if max_value is None:
-            LOGGER.warning("max_value is None, setting it to 2")
-            max_value = 2
-
-        if None in shape:
-            LOGGER.warning("shape contains None, setting it to (1, 1)")
-            shape = (1, 1)
-
         return torch.randint(min_value, max_value, shape)
 
     @staticmethod
     def generate_random_floats(min_value: float, max_value: float, shape: Tuple[int]):
-        if min_value is None:
-            LOGGER.warning("min_value is None, setting it to 0")
-            min_value = 0
-
-        if max_value is None:
-            LOGGER.warning("max_value is None, setting it to 1")
-            max_value = 1
-
-        if None in shape:
-            LOGGER.warning("shape contains None, setting it to (1, 1)")
-            shape = (1, 1)
-
         return torch.rand(shape) * (max_value - min_value) + min_value
 
     @staticmethod
     def generate_ranges(start: int, stop: int, shape: Tuple[int]):
-        if start is None:
-            LOGGER.warning("start is None, setting it to 0")
-            start = 0
-
-        if stop is None:
-            LOGGER.warning("stop is None, setting it to 1")
-            stop = 1
-
-        if None in shape:
-            LOGGER.warning("shape contains None, setting it to (1, 1)")
-            shape = (1, 1)
-
         return torch.arange(start, stop).repeat(shape[0], 1)
 
     @staticmethod
-    def generate_random_strings(shape: Tuple[int]):
-        if None in shape:
-            LOGGER.warning("shape contains None, setting it to (1, 1)")
-            shape = (1, 1)
-
+    def generate_random_strings(num_seq: int) -> List[str]:
         return [
-            "".join(random.choice(string.ascii_letters + string.digits) for _ in range(shape[1]))
-            for _ in range(shape[0])
+            "".join(random.choice(string.ascii_letters + string.digits) for _ in range(random.randint(10, 100)))
+            for _ in range(num_seq)
         ]
 
     def __call__(self):
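The reworked string helper no longer takes a (batch, length) shape; each string now gets a random length between 10 and 100 characters. A self-contained sketch of the new behaviour, mirroring the method above as a plain function:

import random
import string
from typing import List


def generate_random_strings(num_seq: int) -> List[str]:
    # One alphanumeric string per sequence, each with a random length in [10, 100].
    return [
        "".join(random.choice(string.ascii_letters + string.digits) for _ in range(random.randint(10, 100)))
        for _ in range(num_seq)
    ]


strings = generate_random_strings(num_seq=4)
print([len(s) for s in strings])  # four lengths somewhere between 10 and 100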
@@ -82,7 +46,7 @@ class TextGenerator(TaskGenerator):
     def input_ids(self):
         return self.generate_random_integers(
             min_value=0,
-            max_value=self.shapes["vocab_size"],
+            max_value=self.shapes["vocab_size"] or DEFAULT_VOCAB_SIZE,
             shape=(self.shapes["batch_size"], self.shapes["sequence_length"]),
         )
 
@@ -96,7 +60,7 @@ def attention_mask(self):
     def token_type_ids(self):
         return self.generate_random_integers(
             min_value=0,
-            max_value=self.shapes["type_vocab_size"],
+            max_value=self.shapes["type_vocab_size"] or DEFAULT_TYPE_VOCAB_SIZE,
             shape=(self.shapes["batch_size"], self.shapes["sequence_length"]),
         )
 
@@ -140,7 +104,7 @@ def input_features(self):
 class TextClassificationGenerator(TextGenerator):
     def labels(self):
         return self.generate_random_integers(
-            min_value=0, max_value=self.shapes["num_labels"], shape=(self.shapes["batch_size"],)
+            min_value=0, max_value=self.shapes["num_labels"] or DEFAULT_NUM_LABELS, shape=(self.shapes["batch_size"],)
         )
 
     def __call__(self):
@@ -165,7 +129,7 @@ class TokenClassificationGenerator(TextGenerator):
     def labels(self):
         return self.generate_random_integers(
             min_value=0,
-            max_value=self.shapes["num_labels"],
+            max_value=self.shapes["num_labels"] or DEFAULT_NUM_LABELS,
             shape=(self.shapes["batch_size"], self.shapes["sequence_length"]),
         )
 
@@ -292,7 +256,7 @@ def __call__(self):
 class ImageClassificationGenerator(ImageGenerator):
     def labels(self):
         return self.generate_random_integers(
-            min_value=0, max_value=self.shapes["num_labels"], shape=(self.shapes["batch_size"],)
+            min_value=0, max_value=self.shapes["num_labels"] or DEFAULT_NUM_LABELS, shape=(self.shapes["batch_size"],)
         )
 
     def __call__(self):
@@ -310,7 +274,9 @@ def labels(self):
         return [
             {
                 "class_labels": self.generate_random_integers(
-                    min_value=0, max_value=self.shapes["num_labels"], shape=(self.shapes["num_queries"],)
+                    min_value=0,
+                    max_value=self.shapes["num_labels"] or DEFAULT_NUM_LABELS,
+                    shape=(self.shapes["num_queries"],),
                 ),
                 "boxes": self.generate_random_floats(min_value=-1, max_value=1, shape=(self.shapes["num_queries"], 4)),
             }
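For readers unfamiliar with the detection-style targets built above: each per-image label is a dict holding a (num_queries,) tensor of class indices and a (num_queries, 4) tensor of box coordinates in [-1, 1). A standalone sketch with illustrative sizes (num_labels and num_queries are made up here):

import torch

num_labels, num_queries = 2, 4  # illustrative values

target = {
    "class_labels": torch.randint(0, num_labels, (num_queries,)),
    # Same range as generate_random_floats(min_value=-1, max_value=1, ...): uniform in [-1, 1).
    "boxes": torch.rand(num_queries, 4) * 2 - 1,
}
print(target["class_labels"].shape, target["boxes"].shape)  # torch.Size([4]) torch.Size([4, 4])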
@@ -331,7 +297,7 @@ class SemanticSegmentationGenerator(ImageGenerator):
     def labels(self):
         return self.generate_random_integers(
             min_value=0,
-            max_value=self.shapes["num_labels"],
+            max_value=self.shapes["num_labels"] or DEFAULT_NUM_LABELS,
             shape=(self.shapes["batch_size"], self.shapes["height"], self.shapes["width"]),
         )
 
@@ -348,7 +314,7 @@ def __call__(self):
 class AudioClassificationGenerator(AudioGenerator):
     def labels(self):
         return self.generate_random_integers(
-            min_value=0, max_value=self.shapes["num_labels"], shape=(self.shapes["batch_size"],)
+            min_value=0, max_value=self.shapes["num_labels"] or DEFAULT_NUM_LABELS, shape=(self.shapes["batch_size"],)
         )
 
     def __call__(self):
@@ -365,7 +331,7 @@ class AutomaticSpeechRecognitionGenerator(AudioGenerator):
     def labels(self):
         return self.generate_random_integers(
             min_value=0,
-            max_value=self.shapes["vocab_size"],
+            max_value=self.shapes["vocab_size"] or DEFAULT_VOCAB_SIZE,
             shape=(self.shapes["batch_size"], self.shapes["sequence_length"]),
         )
 
@@ -381,7 +347,7 @@ def __call__(self):
 
 class PromptGenerator(TaskGenerator):
     def prompt(self):
-        return self.generate_random_strings(shape=(self.shapes["batch_size"], 10))
+        return self.generate_random_strings(num_seq=self.shapes["batch_size"])
 
     def __call__(self):
         dummy = {}
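With the helper change above, PromptGenerator only needs a batch size to produce prompts. A hedged usage sketch, assuming a version of optimum-benchmark that includes this change (shape keys required by other generators are omitted):

from optimum_benchmark.generators.task_generator import PromptGenerator

# Only "batch_size" is needed for prompt(); other tasks expect more shape keys.
generator = PromptGenerator(shapes={"batch_size": 2}, with_labels=False)
print(generator.prompt())  # two random alphanumeric prompts of varying length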
@@ -394,12 +360,7 @@ class FeatureExtractionGenerator(TextGenerator, ImageGenerator):
     def __call__(self):
         dummy = {}
 
-        if (
-            "num_channels" in self.shapes
-            and self.shapes["num_channels"] is not None
-            and "height" in self.shapes
-            and self.shapes["height"] is not None
-        ):
+        if self.shapes.get("num_channels", None) is not None and self.shapes.get("height", None) is not None:
             dummy["pixel_values"] = self.pixel_values()
         else:
             dummy["input_ids"] = self.input_ids()
