remove dp tp distinction
IlyasMoutawwakil committed Nov 27, 2024
1 parent b64a514 commit 13bc8c0
Showing 14 changed files with 181 additions and 268 deletions.
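
In short, the per-backend is_dp_distributed / is_tp_distributed properties collapse into a single split_between_processes property, the manual batch-size division in prepare_input_shapes is dropped, and input sharding is delegated to Accelerate inside prepare_inputs. A minimal sketch of the pattern the backends converge on below (condensed for illustration; the import path of is_torch_distributed_available is assumed):

# Sketch only: condensed from the backend diffs below, not the verbatim implementation.
from typing import Any, Dict

import torch
from accelerate import Accelerator

from optimum_benchmark.import_utils import is_torch_distributed_available  # assumed import path


class BackendSketch:
    @property
    def split_between_processes(self) -> bool:
        # Single property replacing is_dp_distributed / is_tp_distributed.
        return is_torch_distributed_available() and torch.distributed.is_initialized()

    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if self.split_between_processes:
            # Accelerate shards the batch across ranks; no manual world-size division.
            with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
                inputs = process_inputs
        return inputs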
21 changes: 2 additions & 19 deletions optimum_benchmark/backends/ipex/backend.py
@@ -84,31 +84,14 @@ def automodel_kwargs(self) -> Dict[str, Any]:
if self.config.torch_dtype is not None:
kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)

print(kwargs)

return kwargs

@property
def is_dp_distributed(self) -> bool:
def split_between_processes(self) -> bool:
return is_torch_distributed_available() and torch.distributed.is_initialized()

def prepare_input_shapes(self, input_shapes: Dict[str, Any]) -> Dict[str, Any]:
if self.is_dp_distributed:
if input_shapes["batch_size"] % torch.distributed.get_world_size() != 0:
raise ValueError(
f"Batch size {input_shapes['batch_size']} must be divisible by "
f"data parallel world size {torch.distributed.get_world_size()}"
)
# distributing batch size across processes
input_shapes["batch_size"] //= torch.distributed.get_world_size()

# registering input shapes for usage during model reshaping
self.input_shapes = input_shapes

return input_shapes

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if self.is_dp_distributed:
if self.split_between_processes:
with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
inputs = process_inputs

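For context on the Accelerate call the backends now rely on, a hedged standalone sketch of split_between_processes (the toy tensor shape and two-rank setup are illustrative assumptions, not part of the commit):

import torch
from accelerate import Accelerator

accelerator = Accelerator()  # reads rank and world size from the torch.distributed environment

inputs = {"input_ids": torch.randint(0, 100, (8, 16))}  # global batch of 8 sequences

# Each process receives its slice of the batch along the first dimension;
# apply_padding=False leaves uneven slices unpadded, matching the backends' usage.
with accelerator.split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
    local_inputs = process_inputs  # e.g. a batch of 4 per process when running on 2 ranks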
16 changes: 4 additions & 12 deletions optimum_benchmark/backends/onnxruntime/backend.py
@@ -280,20 +280,12 @@ def quantize_onnx_files(self) -> None:
if self.pretrained_config is not None:
self.pretrained_config.save_pretrained(self.quantized_model)

def prepare_input_shapes(self, input_shapes: Dict[str, Any]) -> Dict[str, Any]:
if self.is_dp_distributed:
if input_shapes["batch_size"] % torch.distributed.get_world_size() != 0:
raise ValueError(
f"Batch size {input_shapes['batch_size']} must be divisible by "
f"data parallel world size {torch.distributed.get_world_size()}"
)
# distributing batch size across processes
input_shapes["batch_size"] //= torch.distributed.get_world_size()

return input_shapes
@property
def split_between_processes(self) -> bool:
return is_torch_distributed_available() and torch.distributed.is_initialized()

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if self.is_dp_distributed:
if self.split_between_processes:
with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
inputs = process_inputs

49 changes: 21 additions & 28 deletions optimum_benchmark/backends/openvino/backend.py
@@ -82,7 +82,7 @@ def load(self) -> None:
if self.config.reshape:
static_shapes = {
key: value
for key, value in {**self.input_shapes, **self.model_shapes}.items()
for key, value in self.model_shapes.items()
if key in inspect.getfullargspec(self.pretrained_model.reshape).args
}
if ("sequence_length" in static_shapes) and ("height" in static_shapes) and ("width" in static_shapes):
@@ -135,20 +135,6 @@ def _load_ovmodel_with_no_weights(self) -> None:
self.config.export = original_export
self.config.model = original_model

@property
def is_dp_distributed(self) -> bool:
return is_torch_distributed_available() and torch.distributed.is_initialized()

@property
def ovmodel_kwargs(self) -> Dict[str, Any]:
kwargs = {}

if self.config.task in TEXT_GENERATION_TASKS:
kwargs["use_cache"] = self.config.use_cache
kwargs["use_merged"] = self.config.use_merged

return kwargs

def quantize_automodel(self) -> None:
self.logger.info("\t+ Attempting quantization")
self.quantized_model = f"{self.tmpdir.name}/quantized_model"
@@ -181,30 +167,37 @@ def quantize_automodel(self) -> None:
batch_size=1,
)

def prepare_input_shapes(self, input_shapes: Dict[str, Any]) -> Dict[str, Any]:
if self.is_dp_distributed:
if input_shapes["batch_size"] % torch.distributed.get_world_size() != 0:
raise ValueError(
f"Batch size {input_shapes['batch_size']} must be divisible by "
f"data parallel world size {torch.distributed.get_world_size()}"
)
# distributing batch size across processes
input_shapes["batch_size"] //= torch.distributed.get_world_size()
@property
def ovmodel_kwargs(self) -> Dict[str, Any]:
kwargs = {}

# registering input shapes for usage during model reshaping
self.input_shapes = input_shapes
if self.config.task in TEXT_GENERATION_TASKS:
kwargs["use_cache"] = self.config.use_cache
kwargs["use_merged"] = self.config.use_merged

return input_shapes
return kwargs

@property
def split_between_processes(self) -> bool:
return is_torch_distributed_available() and torch.distributed.is_initialized()

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if self.is_dp_distributed:
if self.split_between_processes:
with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
inputs = process_inputs

for key in list(inputs.keys()):
if hasattr(self.pretrained_model, "input_names") and key not in self.pretrained_model.input_names:
inputs.pop(key)

if "input_ids" in inputs:
self.model_shapes.update(dict(zip(["batch_size", "sequence_length"], inputs["input_ids"].shape)))

if "pixel_values" in inputs:
self.model_shapes.update(
dict(zip(["batch_size", "num_channels", "height", "width"], inputs["pixel_values"].shape))
)

return inputs

def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
45 changes: 12 additions & 33 deletions optimum_benchmark/backends/pytorch/backend.py
@@ -27,7 +27,7 @@
import deepspeed # type: ignore

if is_torch_distributed_available():
import torch.distributed
import torch.distributed # type: ignore

if is_zentorch_available():
import zentorch # type: ignore # noqa: F401
@@ -326,18 +326,6 @@ def process_quantization_config(self) -> None:
else:
raise ValueError(f"Quantization scheme {self.config.quantization_scheme} not recognized")

@property
def is_distributed(self) -> bool:
return is_torch_distributed_available() and torch.distributed.is_initialized()

@property
def is_tp_distributed(self) -> bool:
return self.is_distributed and self.config.deepspeed_inference

@property
def is_dp_distributed(self) -> bool:
return self.is_distributed and not self.config.deepspeed_inference

@property
def is_quantized(self) -> bool:
return self.config.quantization_scheme is not None or (
@@ -407,35 +395,26 @@ def automodel_kwargs(self) -> Dict[str, Any]:

return kwargs

def prepare_input_shapes(self, input_shapes: Dict[str, Any]) -> Dict[str, Any]:
if self.is_dp_distributed:
if input_shapes["batch_size"] % torch.distributed.get_world_size() != 0:
raise ValueError(
f"Batch size {input_shapes['batch_size']} must be divisible by "
f"data parallel world size {torch.distributed.get_world_size()}"
)
# distributing batch size across processes
input_shapes["batch_size"] //= torch.distributed.get_world_size()

if self.is_tp_distributed:
if torch.distributed.get_rank() != 0:
# zeroing throughput on other ranks
input_shapes["batch_size"] = 0

return input_shapes
@property
def split_between_processes(self) -> bool:
return (
is_torch_distributed_available()
and torch.distributed.is_initialized()
and not self.config.deepspeed_inference
)

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if self.is_dp_distributed:
if self.split_between_processes:
with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
inputs = process_inputs

if self.config.library == "timm":
inputs = {"x": inputs["pixel_values"]}

for key, value in inputs.items():
if isinstance(value, torch.Tensor):
inputs[key] = value.to(self.config.device)

if self.config.library == "timm":
inputs = {"x": inputs["pixel_values"]}

return inputs

@torch.inference_mode()
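One nuance carried over from the old dp/tp split: the PyTorch backend only shards inputs when DeepSpeed inference is disabled, since with tensor parallelism every rank runs the full batch through its shard of the model. A hedged restatement of the condition above (the bare config object is a stand-in):

# Sketch: inputs are split across ranks only in the plain data-parallel case.
should_split = (
    is_torch_distributed_available()
    and torch.distributed.is_initialized()
    and not config.deepspeed_inference  # DeepSpeed tensor parallelism keeps the full batch on every rank
)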
43 changes: 17 additions & 26 deletions optimum_benchmark/backends/transformers_utils.py
@@ -1,12 +1,13 @@
from contextlib import contextmanager
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Type, Union

import torch
import transformers
from transformers import (
AutoConfig,
AutoFeatureExtractor,
AutoImageProcessor,
AutoModel,
AutoProcessor,
AutoTokenizer,
FeatureExtractionMixin,
@@ -17,9 +18,7 @@
SpecialTokensMixin,
)

from ..import_utils import is_torch_available

TASKS_TO_MODEL_LOADERS = {
TASKS_TO_AUTOMODEL_CLASS_NAMES = {
# text processing
"feature-extraction": "AutoModel",
"fill-mask": "AutoModelForMaskedLM",
@@ -57,34 +56,26 @@
"sentence-similarity": "feature-extraction",
}

if is_torch_available():
TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {}
for task_name, model_loaders in TASKS_TO_MODEL_LOADERS.items():
TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES[task_name] = {}

if isinstance(model_loaders, str):
model_loaders = (model_loaders,)

for model_loader_name in model_loaders:
model_loader_class = getattr(transformers, model_loader_name, None)
if model_loader_class is not None:
TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES[task_name].update(
model_loader_class._model_mapping._model_mapping
)
else:
TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {}


def get_transformers_automodel_loader_for_task(task: str, model_type: Optional[str] = None):
def get_transformers_automodel_class_for_task(task: str, model_type: Optional[str] = None) -> Type["AutoModel"]:
if task in SYNONYM_TASKS:
task = SYNONYM_TASKS[task]

if model_type is not None:
model_loader_name = TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES[task][model_type]
if task not in TASKS_TO_AUTOMODEL_CLASS_NAMES:
raise ValueError(f"Task {task} not supported")

if isinstance(TASKS_TO_AUTOMODEL_CLASS_NAMES[task], str):
return getattr(transformers, TASKS_TO_AUTOMODEL_CLASS_NAMES[task])
else:
model_loader_name = TASKS_TO_MODEL_LOADERS[task]
if model_type is None:
raise ValueError(f"Task {task} requires a model_type to be specified")

for automodel_class_name in TASKS_TO_AUTOMODEL_CLASS_NAMES[task]:
automodel_class = getattr(transformers, automodel_class_name)
if model_type in automodel_class._model_mapping._model_mapping:
return automodel_class

return getattr(transformers, model_loader_name)
raise ValueError(f"Task {task} not supported for model type {model_type}")


PretrainedProcessor = Union["FeatureExtractionMixin", "ImageProcessingMixin", "SpecialTokensMixin", "ProcessorMixin"]
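A hedged usage sketch of the renamed helper, using the fill-mask entry visible in the mapping above (the resolved class is what the mapping implies, not independently verified output):

from optimum_benchmark.backends.transformers_utils import get_transformers_automodel_class_for_task

# Tasks mapped to a single class name resolve directly, with no model_type needed.
automodel_class = get_transformers_automodel_class_for_task("fill-mask")
# -> transformers.AutoModelForMaskedLM

# Tasks whose entry is a sequence of class names require model_type, so the helper
# can return the class whose _model_mapping contains that architecture.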
26 changes: 18 additions & 8 deletions optimum_benchmark/benchmark/report.py
@@ -35,16 +35,26 @@ def __post_init__(self):
self.efficiency = Efficiency(**self.efficiency)

@staticmethod
def aggregate(measurements: List["TargetMeasurements"]) -> "TargetMeasurements":
def aggregate_across_processes(measurements: List["TargetMeasurements"]) -> "TargetMeasurements":
assert len(measurements) > 0, "No measurements to aggregate"

m0 = measurements[0]

memory = Memory.aggregate([m.memory for m in measurements]) if m0.memory is not None else None
latency = Latency.aggregate([m.latency for m in measurements]) if m0.latency is not None else None
throughput = Throughput.aggregate([m.throughput for m in measurements]) if m0.throughput is not None else None
energy = Energy.aggregate([m.energy for m in measurements]) if m0.energy is not None else None
efficiency = Efficiency.aggregate([m.efficiency for m in measurements]) if m0.efficiency is not None else None
memory = Memory.aggregate_across_processes([m.memory for m in measurements]) if m0.memory is not None else None
latency = (
Latency.aggregate_across_processes([m.latency for m in measurements]) if m0.latency is not None else None
)
throughput = (
Throughput.aggregate_across_processes([m.throughput for m in measurements])
if m0.throughput is not None
else None
)
energy = Energy.aggregate_across_processes([m.energy for m in measurements]) if m0.energy is not None else None
efficiency = (
Efficiency.aggregate_across_processes([m.efficiency for m in measurements])
if m0.efficiency is not None
else None
)

return TargetMeasurements(
memory=memory, latency=latency, throughput=throughput, energy=energy, efficiency=efficiency
@@ -99,11 +109,11 @@ def __post_init__(self):
setattr(self, target, TargetMeasurements(**getattr(self, target)))

@classmethod
def aggregate(cls, reports: List["BenchmarkReport"]) -> "BenchmarkReport":
def aggregate_across_processes(cls, reports: List["BenchmarkReport"]) -> "BenchmarkReport":
aggregated_measurements = {}
for target in reports[0].to_dict().keys():
measurements = [getattr(report, target) for report in reports]
aggregated_measurements[target] = TargetMeasurements.aggregate(measurements)
aggregated_measurements[target] = TargetMeasurements.aggregate_across_processes(measurements)

return cls.from_dict(aggregated_measurements)

2 changes: 0 additions & 2 deletions optimum_benchmark/generators/task_generator.py
@@ -445,6 +445,4 @@ def __call__(self):
"image-text-to-text": ImageTextToTextGenerator,
# diffusers pipelines tasks
"text-to-image": PromptGenerator,
"stable-diffusion": PromptGenerator,
"stable-diffusion-xl": PromptGenerator,
}
2 changes: 1 addition & 1 deletion optimum_benchmark/launchers/torchrun/launcher.py
@@ -100,7 +100,7 @@ def launch(self, worker: Callable[..., BenchmarkReport], worker_args: List[Any])
raise RuntimeError(f"Received an unexpected response from isolated process: {output}")

self.logger.info("\t+ Aggregating reports from all rank processes")
report = BenchmarkReport.aggregate(reports)
report = BenchmarkReport.aggregate_across_processes(reports)
return report


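Taken together with the report changes above, a hedged sketch of the launcher-side flow (the per-rank gathering step is an illustrative placeholder):

from optimum_benchmark.benchmark.report import BenchmarkReport

# One BenchmarkReport per rank process, gathered by the torchrun launcher.
reports = gather_reports_from_all_ranks()  # placeholder, not a real helper

# Each target's memory / latency / throughput / energy / efficiency measurements
# are merged with the corresponding aggregate_across_processes method.
report = BenchmarkReport.aggregate_across_processes(reports)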
(Diffs for the remaining 6 changed files are not shown here.)
