diff --git a/optimum_benchmark/generators/task_generator.py b/optimum_benchmark/generators/task_generator.py index 5f19c58a..d68b6a48 100644 --- a/optimum_benchmark/generators/task_generator.py +++ b/optimum_benchmark/generators/task_generator.py @@ -2,13 +2,17 @@ import random import string from abc import ABC -from typing import Tuple +from typing import List, Tuple # TODO: drop torch dependency and use numpy instead import torch LOGGER = logging.getLogger("generators") +DEFAULT_NUM_LABELS = 2 +DEFAULT_VOCAB_SIZE = 2 +DEFAULT_TYPE_VOCAB_SIZE = 2 + class TaskGenerator(ABC): def __init__(self, shapes, with_labels: bool): @@ -17,61 +21,21 @@ def __init__(self, shapes, with_labels: bool): @staticmethod def generate_random_integers(min_value: int, max_value: int, shape: Tuple[int]): - if min_value is None: - LOGGER.warning("min_value is None, setting it to 0") - min_value = 0 - - if max_value is None: - LOGGER.warning("max_value is None, setting it to 2") - max_value = 2 - - if None in shape: - LOGGER.warning("shape contains None, setting it to (1, 1)") - shape = (1, 1) - return torch.randint(min_value, max_value, shape) @staticmethod def generate_random_floats(min_value: float, max_value: float, shape: Tuple[int]): - if min_value is None: - LOGGER.warning("min_value is None, setting it to 0") - min_value = 0 - - if max_value is None: - LOGGER.warning("max_value is None, setting it to 1") - max_value = 1 - - if None in shape: - LOGGER.warning("shape contains None, setting it to (1, 1)") - shape = (1, 1) - return torch.rand(shape) * (max_value - min_value) + min_value @staticmethod def generate_ranges(start: int, stop: int, shape: Tuple[int]): - if start is None: - LOGGER.warning("start is None, setting it to 0") - start = 0 - - if stop is None: - LOGGER.warning("stop is None, setting it to 1") - stop = 1 - - if None in shape: - LOGGER.warning("shape contains None, setting it to (1, 1)") - shape = (1, 1) - return torch.arange(start, stop).repeat(shape[0], 1) @staticmethod - def generate_random_strings(shape: Tuple[int]): - if None in shape: - LOGGER.warning("shape contains None, setting it to (1, 1)") - shape = (1, 1) - + def generate_random_strings(num_seq: int) -> List[str]: return [ - "".join(random.choice(string.ascii_letters + string.digits) for _ in range(shape[1])) - for _ in range(shape[0]) + "".join(random.choice(string.ascii_letters + string.digits) for _ in range(random.randint(10, 100))) + for _ in range(num_seq) ] def __call__(self): @@ -82,7 +46,7 @@ class TextGenerator(TaskGenerator): def input_ids(self): return self.generate_random_integers( min_value=0, - max_value=self.shapes["vocab_size"], + max_value=self.shapes["vocab_size"] or DEFAULT_VOCAB_SIZE, shape=(self.shapes["batch_size"], self.shapes["sequence_length"]), ) @@ -96,7 +60,7 @@ def attention_mask(self): def token_type_ids(self): return self.generate_random_integers( min_value=0, - max_value=self.shapes["type_vocab_size"], + max_value=self.shapes["type_vocab_siza"] or DEFAULT_TYPE_VOCAB_SIZE, shape=(self.shapes["batch_size"], self.shapes["sequence_length"]), ) @@ -140,7 +104,7 @@ def input_features(self): class TextClassificationGenerator(TextGenerator): def labels(self): return self.generate_random_integers( - min_value=0, max_value=self.shapes["num_labels"], shape=(self.shapes["batch_size"],) + min_value=0, max_value=self.shapes["num_labels"] or DEFAULT_NUM_LABELS, shape=(self.shapes["batch_size"],) ) def __call__(self): @@ -165,7 +129,7 @@ class TokenClassificationGenerator(TextGenerator): def labels(self): return self.generate_random_integers( min_value=0, - max_value=self.shapes["num_labels"], + max_value=self.shapes["num_labels"] or DEFAULT_NUM_LABELS, shape=(self.shapes["batch_size"], self.shapes["sequence_length"]), ) @@ -292,7 +256,7 @@ def __call__(self): class ImageClassificationGenerator(ImageGenerator): def labels(self): return self.generate_random_integers( - min_value=0, max_value=self.shapes["num_labels"], shape=(self.shapes["batch_size"],) + min_value=0, max_value=self.shapes["num_labels"] or DEFAULT_NUM_LABELS, shape=(self.shapes["batch_size"],) ) def __call__(self): @@ -310,7 +274,9 @@ def labels(self): return [ { "class_labels": self.generate_random_integers( - min_value=0, max_value=self.shapes["num_labels"], shape=(self.shapes["num_queries"],) + min_value=0, + max_value=self.shapes["num_labels"] or DEFAULT_NUM_LABELS, + shape=(self.shapes["num_queries"],), ), "boxes": self.generate_random_floats(min_value=-1, max_value=1, shape=(self.shapes["num_queries"], 4)), } @@ -331,7 +297,7 @@ class SemanticSegmentationGenerator(ImageGenerator): def labels(self): return self.generate_random_integers( min_value=0, - max_value=self.shapes["num_labels"], + max_value=self.shapes["num_labels"] or DEFAULT_NUM_LABELS, shape=(self.shapes["batch_size"], self.shapes["height"], self.shapes["width"]), ) @@ -348,7 +314,7 @@ def __call__(self): class AudioClassificationGenerator(AudioGenerator): def labels(self): return self.generate_random_integers( - min_value=0, max_value=self.shapes["num_labels"], shape=(self.shapes["batch_size"],) + min_value=0, max_value=self.shapes["num_labels"] or DEFAULT_NUM_LABELS, shape=(self.shapes["batch_size"],) ) def __call__(self): @@ -365,7 +331,7 @@ class AutomaticSpeechRecognitionGenerator(AudioGenerator): def labels(self): return self.generate_random_integers( min_value=0, - max_value=self.shapes["vocab_size"], + max_value=self.shapes["vocab_size"] or DEFAULT_TYPE_VOCAB_SIZE, shape=(self.shapes["batch_size"], self.shapes["sequence_length"]), ) @@ -381,7 +347,7 @@ def __call__(self): class PromptGenerator(TaskGenerator): def prompt(self): - return self.generate_random_strings(shape=(self.shapes["batch_size"], 10)) + return self.generate_random_strings(num_seq=self.shapes["batch_size"]) def __call__(self): dummy = {} @@ -394,12 +360,7 @@ class FeatureExtractionGenerator(TextGenerator, ImageGenerator): def __call__(self): dummy = {} - if ( - "num_channels" in self.shapes - and self.shapes["num_channels"] is not None - and "height" in self.shapes - and self.shapes["height"] is not None - ): + if self.shapes.get("num_channels", None) is not None and self.shapes.get("height", None) is not None: dummy["pixel_values"] = self.pixel_values() else: dummy["input_ids"] = self.input_ids()