Commit 917f863

… into main

IlyasMoutawwakil committed Sep 15, 2023
2 parents 3bf52a5 + d0194d1
Showing 13 changed files with 84 additions and 66 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -120,7 +120,7 @@ Also, for integer parameters like `batch_size`, one can specify a range of values
optimum-benchmark --config-dir examples --config-name pytorch_bert -m device=cpu,cuda benchmark.input_shapes.batch_size='range(1,10,step=2)'
```
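
For reference, here is a quick sketch of the batch sizes such a sweep would cover, assuming Hydra's `range(start, stop, step)` follows the same stop-exclusive convention as Python's built-in `range`:

```python
# Values swept by benchmark.input_shapes.batch_size='range(1,10,step=2)',
# assuming the same semantics as Python's built-in range (stop exclusive).
batch_sizes = list(range(1, 10, 2))
print(batch_sizes)  # [1, 3, 5, 7, 9]
```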

## Reporting benchamrk results (WIP)
## Reporting benchmark results (WIP)

To aggregate the results of a benchmark (run(s) or sweep(s)), you can use the `optimum-report` command.

@@ -135,7 +135,7 @@ You can also reuse some components of the reporting script for your use case (ex
## Configurations structure

You can create custom configuration files following the [examples here](examples).
You can also use `hydra`'s [composition](https://hydra.cc/docs/0.11/tutorial/composition/) with a base configuratin ([`examples/pytorch_bert.yaml`](examples/pytorch_bert.yaml) for example) and override/define parameters.
You can also use `hydra`'s [composition](https://hydra.cc/docs/0.11/tutorial/composition/) with a base configuration ([`examples/pytorch_bert.yaml`](examples/pytorch_bert.yaml) for example) and override/define parameters.

To create a configuration that uses a `wav2vec2` model and `onnxruntime` backend, it's as easy as:

4 changes: 2 additions & 2 deletions examples/whisper/README.md
@@ -13,7 +13,7 @@ Where `${device}` is either `cpu` or `cuda`.

## Metrics

Fo this benchmark I tried to compare `whisper-base` model's throughputs (forward and generate).
For this benchmark I tried to compare `whisper-base` model's throughputs (forward and generate).

Forward throughput is measured in `samples/second` with the formula `number_processed_samples / total_time`.
Where `number_processed_samples = batch_size * number_forward_passes` is the number of samples processed by the model in `total_time`.
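
As a quick illustration of the formula above, here is a minimal Python sketch (the numbers are made up for the example, not taken from the benchmark):

```python
# Forward throughput in samples/second, following the formula above:
# number_processed_samples = batch_size * number_forward_passes
batch_size = 64
number_forward_passes = 500
total_time = 10.0  # seconds spent running forward passes

number_processed_samples = batch_size * number_forward_passes
forward_throughput = number_processed_samples / total_time
print(f"{forward_throughput:.1f} samples/s")  # 3200.0 samples/s
```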
@@ -23,7 +23,7 @@ Where `number_generated_tokens = batch_size * num_tokens * number_generate_passe

## Search Space

To be exhaustive, I benchmarked different auto optimization configurations supported by Optimum on GPU & CPU and auto quantization configrations on CPU only.
To be exhaustive, I benchmarked different auto optimization configurations supported by Optimum on GPU & CPU and auto quantization configurations on CPU only.

I also added `benchmark.batch_size=64,128 benchmark.new_tokens=10,100` to compare behavior across different batch sizes and number of generated tokens.

14 changes: 7 additions & 7 deletions optimum_benchmark/backends/base.py
@@ -50,7 +50,7 @@
class Backend(Generic[BackendConfigT], ABC):
NAME: ClassVar[str]

# instance variables withouth default values https://stackoverflow.com/a/44962662
# instance variables without default values https://stackoverflow.com/a/44962662
config: BackendConfigT
pretrained_model: Union["PreTrainedModel", "Pipeline"]
pretrained_processor: Optional["PreTrainedProcessor"]
@@ -117,7 +117,7 @@ def check_continuous_isolation(self) -> None:
else:
device_ids = list(map(int, CUDA_VISIBLE_DEVICES.split(",")))

LOGGER.info(f"\t+ Checking contineous device(s) isolation of CUDA device(s): {device_ids}")
LOGGER.info(f"\t+ Checking continuous device(s) isolation of CUDA device(s): {device_ids}")
self.isolation_thread = Process(
target=check_only_this_process_is_running_on_cuda_device,
args=(device_ids, os.getpid()),
@@ -165,13 +165,13 @@ def prepare_input(self, input: Dict[str, Any]) -> Dict[str, Any]:

return input

# compiling in openvino requires input shapes
def prepare_for_inference(self, input_shapes: Dict[str, int]) -> Dict[str, Any]:
# compiling in openvino requires input shapes, trt ep requires max tokens, etc.
def prepare_for_inference(self, **kwargs) -> None:
pass

# symbolic tracing in transformers requires input names
def prepare_for_profiling(self, input_names: List[str]) -> Dict[str, Any]:
pass
# # symbolic tracing in transformers requires input names
# def prepare_for_profiling(self, input_names: List[str]) -> Dict[str, Any]:
# pass

def forward(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> "ModelOutput":
return self.pretrained_model(**input, **kwargs)
2 changes: 1 addition & 1 deletion optimum_benchmark/backends/neural_compressor/backend.py
@@ -27,7 +27,7 @@ def __init__(self, model: str, task: str, device: str, hub_kwargs: Dict[str, Any

self.incmodel_class = get_class(TASKS_TO_INCMODELS[self.task])
LOGGER.info(
f"\t+ Infered INCModel {self.incmodel_class.__name__} for task {self.task} and model_type {self.model_type}"
f"\t+ Inferred INCModel {self.incmodel_class.__name__} for task {self.task} and model_type {self.model_type}"
)

def validate_device(self) -> None:
32 changes: 25 additions & 7 deletions optimum_benchmark/backends/onnxruntime/backend.py
@@ -53,7 +53,7 @@ def __init__(self, model: str, task: str, device: str, hub_kwargs: Dict[str, Any

ortmodel_name = self.ortmodel_class.__name__
LOGGER.info(
f"\t+ Infered ORTModel class {ortmodel_name} for task {self.task} and model_type {self.model_type}"
f"\t+ Inferred ORTModel class {ortmodel_name} for task {self.task} and model_type {self.model_type}"
)

def validate_device(self) -> None:
@@ -71,7 +71,7 @@ def configure(self, config: ORTConfig) -> None:
self.torch_dtype = getattr(torch, self.config.torch_dtype) if self.config.torch_dtype is not None else None

###### Training with ORTModule ######
# ort-training is basically a different package so we might need to seperate these two backends in the future
# ort-training is basically a different package so we might need to separate these two backends in the future
if not self.config.use_inference_session:
if self.config.no_weights:
self.load_automodel_from_config()
@@ -102,6 +102,7 @@ def configure(self, config: ORTConfig) -> None:
# Some statefullness to handle the different combinations of options
self.export = self.config.export
self.use_merged = self.config.use_merged
self.provider_options = self.config.provider_options.copy()

if self.is_diffusion_pipeline():
self.load_ortmodel()
@@ -114,7 +115,7 @@ def configure(self, config: ORTConfig) -> None:
self.export = False
else:
if self.config.export:
self.use_merged = False # merging is handeled seperately
self.use_merged = False # merging is handled separately
self.load_automodel_from_pretrained() # creates automodel from pretrained
self.export_automodel() # exports automodel
self.export = False
@@ -131,11 +132,11 @@ def configure(self, config: ORTConfig) -> None:
if self.config.auto_quantization or self.config.quantization:
self.quantize_onnx_files()

self.load_ortmodel()
self.tmpdir.cleanup()
if not (self.config.provider == "TensorrtExecutionProvider" and self.is_text_generation_model()):
self.load_ortmodel()
self.tmpdir.cleanup()

def load_automodel_from_config(self) -> None:
# TODO: create no_weights tests
from accelerate import init_empty_weights

LOGGER.info("\t+ Loading AutoModel from config")
@@ -164,8 +165,8 @@ def load_ortmodel(self) -> None:
export=self.export,
provider=self.config.provider,
session_options=self.session_options,
provider_options=self.provider_options,
use_io_binding=self.config.use_io_binding,
provider_options=self.config.provider_options,
**self.ortmodel_kwargs,
**self.hub_kwargs,
)
@@ -311,6 +312,23 @@ def quantize_onnx_files(self) -> None:
)
self.model = quantized_model_path

def prepare_for_inference(self, **kwargs) -> None:
if self.config.provider == "TensorrtExecutionProvider" and self.is_text_generation_model():
max_new_tokens = kwargs["max_new_tokens"]
batch_size = kwargs["input_shapes"]["batch_size"]
sequence_length = kwargs["input_shapes"]["sequence_length"]

LOGGER.info("\t+ Creating dynamic shapes for Tensorrt engine, loading will take a while")
self.provider_options = {
**self.provider_options,
"trt_profile_min_shapes": f"input_ids:{batch_size}x{sequence_length},attention_mask:{batch_size}x{sequence_length}",
"trt_profile_max_shapes": f"input_ids:{batch_size}x{sequence_length + max_new_tokens},attention_mask:{batch_size}x{sequence_length + max_new_tokens}",
"trt_profile_opt_shapes": f"input_ids:{batch_size}x{sequence_length + max_new_tokens},attention_mask:{batch_size}x{sequence_length + max_new_tokens}",
}

self.load_ortmodel()
self.tmpdir.cleanup()
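
To make the profile strings above concrete, here is a small sketch of what they evaluate to for example values (the numbers are illustrative, not taken from this commit):

```python
# Illustrative values; in the backend they come from the benchmark config.
batch_size, sequence_length, max_new_tokens = 2, 16, 10

min_shapes = f"input_ids:{batch_size}x{sequence_length},attention_mask:{batch_size}x{sequence_length}"
max_shapes = (
    f"input_ids:{batch_size}x{sequence_length + max_new_tokens},"
    f"attention_mask:{batch_size}x{sequence_length + max_new_tokens}"
)

print(min_shapes)  # input_ids:2x16,attention_mask:2x16
print(max_shapes)  # input_ids:2x26,attention_mask:2x26
```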

def prepare_for_profiling(self, input_names: List[str]) -> None:
LOGGER.info("Preparing model for profiling")
LOGGER.info("\t+ Wrapping model inside profiler")
41 changes: 23 additions & 18 deletions optimum_benchmark/backends/onnxruntime/config.py
@@ -13,31 +13,29 @@
def infer_device_id(device: str) -> int:
"""Infer the device id from the given device string."""
if "cuda" in device:
# here we resolve conflicts between CUDA_VISIBLE_DEVICES and pytorch's indexing
# e.g. CUDA_VISIBLE_DEVICES=1 and device=cuda:0 should return 1
CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES", None)
if CUDA_VISIBLE_DEVICES is None:
if ":" in device:
return int(device.split(":")[1])
else:
return 0
if ":" in device:
# either CUDA_VISIBLE_DEVICES is set or device is set to cuda:0
return int(device.split(":")[1])
else:
if ":" in device:
return int(CUDA_VISIBLE_DEVICES.split(",")[int(device.split(":")[1])])
else:
return int(CUDA_VISIBLE_DEVICES.split(",")[0])
# device is set to cuda
return 0
elif device == "cpu":
return -1
else:
raise ValueError(f"Unknown device: {device}")
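
Reading the added branch of this hunk, the simplified device-id resolution behaves as sketched below (a quick check, not part of the commit):

```python
# Expected behaviour of the simplified infer_device_id
# (import path taken from the file shown above).
from optimum_benchmark.backends.onnxruntime.config import infer_device_id

assert infer_device_id("cuda:1") == 1  # explicit index is used as-is
assert infer_device_id("cuda") == 0    # bare "cuda" defaults to device 0
assert infer_device_id("cpu") == -1    # cpu maps to -1
```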


DEVICE_PROVIDER_MAP = {
"cpu": "CPUExecutionProvider",
"cuda": "CUDAExecutionProvider",
}

OmegaConf.register_new_resolver("onnxruntime_version", onnxruntime_version)
OmegaConf.register_new_resolver("is_gpu", lambda device: "cuda" in device)
OmegaConf.register_new_resolver("infer_device_id", lambda device: infer_device_id(device))
OmegaConf.register_new_resolver("infer_provider", lambda device: DEVICE_PROVIDER_MAP[device])
OmegaConf.register_new_resolver("is_profiling", lambda benchmark_name: benchmark_name == "profiling")
OmegaConf.register_new_resolver(
"infer_provider", lambda device: "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
"io_bind", lambda provider: provider in ["CPUExecutionProvider", "CUDAExecutionProvider"]
)


@@ -97,6 +95,11 @@ def infer_device_id(device: str) -> int:
"preprocess_class": "optimum_benchmark.preprocessors.glue.GluePreprocessor",
}

TRT_PROVIDER_OPTIONS = {
"trt_engine_cache_enable": True,
"trt_engine_cache_path": "tmp/trt_cache",
}


@dataclass
class ORTConfig(BackendConfig):
@@ -114,12 +117,10 @@ class ORTConfig(BackendConfig):

# provider options
provider: str = "${infer_provider:${device}}"
device_id: Optional[int] = "${oc.deprecated:backend.provider_options.device_id}"
provider_options: Dict[str, Any] = field(default_factory=lambda: {"device_id": "${infer_device_id:${device}}"})

# inference options
use_io_binding: bool = "${is_gpu:${device}}"
enable_profiling: bool = "${oc.deprecated:backend.session_options.enable_profiling}"
use_io_binding: bool = "${io_bind:${device}}"
session_options: Dict[str, Any] = field(
default_factory=lambda: {"enable_profiling": "${is_profiling:${benchmark.name}}"}
)
@@ -144,7 +145,7 @@ class ORTConfig(BackendConfig):
auto_quantization: Optional[str] = None
auto_quantization_config: Dict[str, Any] = field(default_factory=dict)

# ort-training is basically a different package so we might need to seperate these two backends in the future
# ort-training is basically a different package so we might need to separate these two backends in the future
use_inference_session: bool = "${is_inference:${benchmark.name}}"

# training options
@@ -161,6 +162,10 @@ def __post_init__(self):
if not self.no_weights and not self.export and self.torch_dtype is not None:
raise NotImplementedError("Can't convert an exported model's weights to a different dtype.")

if self.provider == "TensorrtExecutionProvider":
self.provider_options = OmegaConf.to_object(OmegaConf.merge(TRT_PROVIDER_OPTIONS, self.provider_options))
os.makedirs(self.provider_options["trt_engine_cache_path"], exist_ok=True)

if self.optimization:
self.optimization_config = OmegaConf.to_object(
OmegaConf.merge(OPTIMIZATION_CONFIG, self.optimization_config)
7 changes: 5 additions & 2 deletions optimum_benchmark/backends/openvino/backend.py
@@ -24,7 +24,9 @@ def __init__(self, model: str, task: str, device: str, hub_kwargs: Dict[str, Any

self.ovmodel_class = get_class(TASKS_TO_OVMODEL[self.task])
ortmodel_name = self.ovmodel_class.__name__
LOGGER.info(f"\t+ Infered OVModel class {ortmodel_name} for task {self.task} and model_type {self.model_type}")
LOGGER.info(
f"\t+ Inferred OVModel class {ortmodel_name} for task {self.task} and model_type {self.model_type}"
)

def validate_task(self) -> None:
if self.task not in TASKS_TO_OVMODEL:
@@ -95,7 +97,8 @@ def quantize_automodel(self) -> None:
)
self.model = quantized_model_path

def prepare_for_inference(self, input_shapes: Dict[str, int]) -> None:
def prepare_for_inference(self, **kwargs) -> None:
input_shapes = kwargs["input_shapes"]
if self.config.reshape:
static_shapes = {
key: value
4 changes: 2 additions & 2 deletions optimum_benchmark/backends/optimum_utils.py
@@ -245,7 +245,7 @@ def main_export(
and not is_stable_diffusion
and task + "-with-past" in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx")
):
if original_task == "auto": # Make -with-past the default if --task was not explicitely specified
if original_task == "auto": # Make -with-past the default if --task was not explicitly specified
task = task + "-with-past"
else:
logger.info(
@@ -328,7 +328,7 @@ def main_export(
if model.config.is_encoder_decoder and task.startswith("text-generation"):
raise ValueError(
f"model.config.is_encoder_decoder is True and task is `{task}`, which are incompatible. If the task was auto-inferred, please fill a bug report"
f"at https://github.com/huggingface/optimum, if --task was explicitely passed, make sure you selected the right task for the model,"
f"at https://github.com/huggingface/optimum, if --task was explicitly passed, make sure you selected the right task for the model,"
f" referring to `optimum.exporters.tasks.TaskManager`'s `_TASKS_TO_AUTOMODELS`."
)

2 changes: 1 addition & 1 deletion optimum_benchmark/backends/pytorch/backend.py
@@ -28,7 +28,7 @@ def __init__(self, model: str, task: str, device: str, hub_kwargs: Dict[str, Any
super().__init__(model, task, device, hub_kwargs)

automodel = self.automodel_class.__name__
LOGGER.info(f"\t+ Infered AutoModel class {automodel} for task {self.task} and model_type {self.model_type}")
LOGGER.info(f"\t+ Inferred AutoModel class {automodel} for task {self.task} and model_type {self.model_type}")

def configure(self, config: PyTorchConfig) -> None:
super().configure(config)
@@ -31,7 +31,7 @@ def __init__(self, model: str, task: str, device: str, hub_kwargs: Dict[str, Any
self.validate_task()

automodel = self.automodel_class.__name__
LOGGER.info(f"\t+ Infered AutoModel class {automodel} for task {self.task} and model_type {self.model_type}")
LOGGER.info(f"\t+ Inferred AutoModel class {automodel} for task {self.task} and model_type {self.model_type}")

def validate_task(self) -> None:
if self.task not in ["text-generation", "text2text-generation"]:
@@ -144,7 +144,7 @@ def load_model_from_pretrained(self) -> None:
self.pretrained_model = self.automodel_class.from_pretrained(self.model, **self.hub_kwargs)

def modify_generation_config(self) -> None:
# this should, theorically, make the generated output's sequence length fully controlled by max_new_tokens
# this should, theoretically, make the generated output's sequence length fully controlled by max_new_tokens
# instead of stopping at the first eos_token_id/pad_token_id
generation_config = GenerationConfig.from_pretrained(self.model, **self.hub_kwargs)
generation_config.eos_token_id = -100
@@ -194,7 +194,7 @@ def clean(self) -> None:
super().clean()

if hasattr(self, "tgi_container"):
LOGGER.info("\t+ Stoping TGI container")
LOGGER.info("\t+ Stopping TGI container")
self.tgi_container.stop()
LOGGER.info("\t+ Waiting for TGI container to stop")
self.tgi_container.wait()
9 changes: 6 additions & 3 deletions optimum_benchmark/benchmarks/inference/benchmark.py
@@ -47,6 +47,12 @@ def run(self, backend: "Backend") -> None:
input_shapes=self.config.input_shapes,
)

# openvino requires compiling with static shapes and trt ep requires max tokens
backend.prepare_for_inference(
input_shapes=self.config.input_shapes,
max_new_tokens=self.config.generate_kwargs.get("max_new_tokens", 0),
)

# run forward pass tracking
self.run_forward_tracking(backend)

@@ -60,9 +66,6 @@ def run_forward_tracking(self, backend: "Backend") -> None:
LOGGER.info("\t+ Preparing input for the forward pass")
forward_input = backend.prepare_input(forward_input)

# for backends that require compilation with static shapes
backend.prepare_for_inference(input_shapes=self.config.input_shapes)

LOGGER.info("\t+ Warming up the forward pass")
for _ in range(self.config.warmup_runs):
_ = backend.forward(forward_input, self.config.forward_kwargs)
23 changes: 6 additions & 17 deletions optimum_benchmark/benchmarks/inference/config.py
@@ -34,7 +34,6 @@ class InferenceConfig(BenchmarkConfig):
# benchmark options
duration: int = 10
warmup_runs: int = 10
benchmark_duration: Optional[int] = None # deprecated

# additional/optional metrics
memory: bool = False
@@ -57,9 +56,7 @@ class InferenceConfig(BenchmarkConfig):
},
)

# TODO: deprecate this and use `benchamrk.generate_kwargs`
new_tokens: Optional[int] = None

can_diffuse: bool = "${can_diffuse:${task}}"
can_generate: bool = "${can_generate:${task}}"

@@ -79,17 +76,9 @@ def __post_init__(self):
if self.generate_kwargs["max_new_tokens"] != self.generate_kwargs["min_new_tokens"]:
raise ValueError("`max_new_tokens` and `min_new_tokens` must be equal for fixed length output.")

if self.new_tokens is not None:
LOGGER.warning(
"The `new_tokens` option is deprecated, please use `generate_kwargs` instead. "
"`generate_kwargs.max_new_tokens` and `generate_kwargs.min_new_tokens` will be set to the value of `new_tokens`."
)
self.generate_kwargs["max_new_tokens"] = self.new_tokens
self.generate_kwargs["min_new_tokens"] = self.new_tokens

if self.benchmark_duration is not None:
LOGGER.warning(
"The `benchmark_duration` option is deprecated, please use `duration` instead. "
"`duration` will be set to the value of `benchmark_duration`."
)
self.duration = self.benchmark_duration
if self.new_tokens is not None:
LOGGER.info(
f"`new_tokens` was set to {self.new_tokens}. `max_new_tokens` and `min_new_tokens` will be set to {self.new_tokens}."
)
self.generate_kwargs["max_new_tokens"] = self.new_tokens
self.generate_kwargs["min_new_tokens"] = self.new_tokens