extract text generation inputs from forward inputs
IlyasMoutawwakil committed Feb 20, 2024
1 parent 3211d16 commit 176c1e6
Showing 2 changed files with 41 additions and 21 deletions.
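
Summary of the change: text generation benchmarks previously built one self.text_generation_inputs dict, trimmed it to input_ids, and reused it for both backend.forward and backend.generate. This commit keeps the fully prepared inputs as self.forward_inputs and derives self.generate_inputs from them with a new helper, extract_text_generation_inputs, which also recognizes image, speech, and audio-feature inputs. The image diffusion and plain inference paths are likewise renamed to self.call_inputs and self.forward_inputs, so each attribute now matches the backend method it feeds.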
45 changes: 24 additions & 21 deletions optimum_benchmark/benchmarks/inference/benchmark.py
@@ -11,6 +11,7 @@
 from ..base import Benchmark
 from ..report import BenchmarkMeasurements, BenchmarkReport
 from .config import InferenceConfig
+from .inputs_utils import extract_text_generation_inputs
 
 if is_torch_distributed_available():
     import torch.distributed
@@ -80,26 +81,28 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
 
         if backend.config.task in TEXT_GENERATION_TASKS:
             LOGGER.info("\t+ Generating and preparing Text Generation inputs")
-            self.text_generation_inputs = self.input_generator()
-            self.text_generation_inputs = backend.prepare_inputs(self.text_generation_inputs)
-            self.text_generation_inputs = {"input_ids": self.text_generation_inputs["input_ids"]}
+            self.forward_inputs = self.input_generator()
+            self.forward_inputs = backend.prepare_inputs(self.forward_inputs)
+            self.generate_inputs = extract_text_generation_inputs(self.forward_inputs)
             LOGGER.info("\t+ Updating Text Generation kwargs with default values")
             self.config.generate_kwargs = {**TEXT_GENERATION_KWARGS, **self.config.generate_kwargs}
             LOGGER.info("\t+ Initializing Text Generation report")
             self.report = TextGenerationReport(prefill=BenchmarkMeasurements(), decode=BenchmarkMeasurements())
 
         elif backend.config.task in IMAGE_DIFFUSION_TASKS:
             LOGGER.info("\t+ Generating Image Diffusion inputs")
-            self.image_diffusion_inputs = self.input_generator()
+            self.call_inputs = self.input_generator()
+            self.call_inputs = backend.prepare_inputs(self.call_inputs)
+            self.call_inputs = {"prompt": self.call_inputs["prompt"]}
             LOGGER.info("\t+ Updating Image Diffusion kwargs with default values")
             self.config.call_kwargs = {**IMAGE_DIFFUSION_KWARGS, **self.config.call_kwargs}
             LOGGER.info("\t+ Initializing Image Diffusion report")
             self.report = ImageDiffusionReport(call=BenchmarkMeasurements())
 
         else:
             LOGGER.info("\t+ Generating and preparing Inference inputs")
-            self.inference_inputs = self.input_generator()
-            self.inference_inputs = backend.prepare_inputs(self.inference_inputs)
+            self.forward_inputs = self.input_generator()
+            self.forward_inputs = backend.prepare_inputs(self.forward_inputs)
             LOGGER.info("\t+ Initializing Inference report")
             self.report = InferenceReport(forward=BenchmarkMeasurements())
 
@@ -115,11 +118,11 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
         LOGGER.info("\t+ Warming up backend for Inference")
         for _ in range(self.config.warmup_runs):
             if backend.config.task in TEXT_GENERATION_TASKS:
-                _ = backend.generate(self.text_generation_inputs, {"max_new_tokens": 2, "min_new_tokens": 2})
+                _ = backend.generate(self.generate_inputs, {"max_new_tokens": 2, "min_new_tokens": 2})
             elif backend.config.task in IMAGE_DIFFUSION_TASKS:
-                _ = backend.call(self.image_diffusion_inputs, {"num_inference_steps": 2})
+                _ = backend.call(self.call_inputs, {"num_inference_steps": 2})
             else:
-                _ = backend.forward(self.inference_inputs, self.config.forward_kwargs)
+                _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
 
         if self.config.memory:
             LOGGER.info("\t+ Creating inference memory tracker")
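
(Note that the warmup loop above uses minimal workloads — two new tokens for generation, two inference steps for diffusion — presumably so one-time costs such as allocator and kernel initialization are paid before measurement without spending warmup time on full runs.)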
@@ -166,29 +169,29 @@ def run_text_generation_memory_tracking(self, backend: Backend):
         LOGGER.info("\t+ Running memory tracking")
         self.memory_tracker.reset()
         with self.memory_tracker.track():
-            _ = backend.forward(self.text_generation_inputs, self.config.forward_kwargs)
+            _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
 
         self.report.prefill.memory = self.memory_tracker.get_max_memory()
 
         self.memory_tracker.reset()
         with self.memory_tracker.track():
-            _ = backend.generate(self.text_generation_inputs, self.config.generate_kwargs)
+            _ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
 
         self.report.decode.memory = self.memory_tracker.get_max_memory()
 
     def run_image_diffusion_memory_tracking(self, backend: Backend):
         LOGGER.info("\t+ Running memory tracking")
         self.memory_tracker.reset()
         with self.memory_tracker.track():
-            _ = backend.call(self.image_diffusion_inputs, self.config.call_kwargs)
+            _ = backend.call(self.call_inputs, self.config.call_kwargs)
 
         self.report.call.memory = self.memory_tracker.get_max_memory()
 
     def run_inference_memory_tracking(self, backend: Backend):
         LOGGER.info("\t+ Running memory tracking")
         self.memory_tracker.reset()
         with self.memory_tracker.track():
-            _ = backend.forward(self.inference_inputs, self.config.forward_kwargs)
+            _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
 
         self.report.forward.memory = self.memory_tracker.get_max_memory()
@@ -198,7 +201,7 @@ def run_text_generation_latency_tracking(self, backend: Backend):
         self.latency_tracker.reset()
         while self.latency_tracker.get_elapsed_time() < self.config.duration:
             with self.latency_tracker.track():
-                _ = backend.forward(self.text_generation_inputs, self.config.forward_kwargs)
+                _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
 
         forward_latency = self.latency_tracker.get_latency()
         forward_latency.log(prefix="forward")
@@ -210,7 +213,7 @@ def run_text_generation_latency_tracking(self, backend: Backend):
         self.latency_tracker.reset()
         while self.latency_tracker.get_elapsed_time() < self.config.duration:
             with self.latency_tracker.track():
-                _ = backend.generate(self.text_generation_inputs, self.config.generate_kwargs)
+                _ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
 
         generate_latency = self.latency_tracker.get_latency()
         generate_latency.log(prefix="generate")
@@ -224,7 +227,7 @@ def run_image_diffusion_latency_tracking(self, backend: Backend):
         self.latency_tracker.reset()
         while self.latency_tracker.get_elapsed_time() < self.config.duration:
             with self.latency_tracker.track():
-                _ = backend.call(self.image_diffusion_inputs, self.config.call_kwargs)
+                _ = backend.call(self.call_inputs, self.config.call_kwargs)
 
         self.report.call.latency = self.latency_tracker.get_latency()
         self.report.call.throughput = Throughput.from_latency(
@@ -236,7 +239,7 @@ def run_latency_inference_tracking(self, backend: Backend):
         self.latency_tracker.reset()
         while self.latency_tracker.get_elapsed_time() < self.config.duration:
             with self.latency_tracker.track():
-                _ = backend.forward(self.inference_inputs, self.config.forward_kwargs)
+                _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
 
         self.report.forward.latency = self.latency_tracker.get_latency()
         self.report.forward.throughput = Throughput.from_latency(
@@ -248,7 +251,7 @@ def run_text_generation_energy_tracking(self, backend: Backend):
         LOGGER.info("\t+ Running energy tracking")
         self.energy_tracker.reset()
         with self.energy_tracker.track():
-            _ = backend.forward(self.text_generation_inputs, self.config.forward_kwargs)
+            _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
 
         self.report.prefill.energy = self.energy_tracker.get_energy()
         self.report.prefill.efficiency = Efficiency.from_energy(
@@ -257,7 +260,7 @@ def run_text_generation_energy_tracking(self, backend: Backend):
 
         self.energy_tracker.reset()
         with self.energy_tracker.track():
-            _ = backend.generate(self.text_generation_inputs, self.config.generate_kwargs)
+            _ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
 
         self.report.decode.energy = self.energy_tracker.get_energy() - self.report.prefill.energy
         self.report.decode.efficiency = Efficiency.from_energy(
@@ -268,7 +271,7 @@ def run_image_diffusion_energy_tracking(self, backend: Backend):
         LOGGER.info("\t+ Running energy tracking")
         self.energy_tracker.reset()
         with self.energy_tracker.track():
-            _ = backend.call(self.image_diffusion_inputs, self.config.call_kwargs)
+            _ = backend.call(self.call_inputs, self.config.call_kwargs)
 
         self.report.call.energy = self.energy_tracker.get_energy()
         self.report.call.efficiency = Efficiency.from_energy(
@@ -279,7 +282,7 @@ def run_inference_energy_tracking(self, backend: Backend):
         LOGGER.info("\t+ Running energy tracking")
         self.energy_tracker.reset()
         with self.energy_tracker.track():
-            _ = backend.forward(self.inference_inputs, self.config.forward_kwargs)
+            _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
 
         self.report.forward.energy = self.energy_tracker.get_energy()
         self.report.forward.efficiency = Efficiency.from_energy(
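The latency trackers above all follow the same duration-bounded pattern: reset, then repeat the workload inside a tracking context until the configured duration elapses. A minimal standalone sketch of that pattern (SimpleLatencyTracker below is a hypothetical stand-in, not the library's tracker):

import time
from contextlib import contextmanager

class SimpleLatencyTracker:
    """Hypothetical stand-in for the benchmark's latency tracker."""

    def __init__(self):
        self.latencies = []
        self._origin = time.perf_counter()

    def reset(self):
        self.latencies = []
        self._origin = time.perf_counter()

    def get_elapsed_time(self):
        # Total wall-clock time since reset, used to bound the measurement loop.
        return time.perf_counter() - self._origin

    @contextmanager
    def track(self):
        # Record the wall-clock latency of the enclosed workload.
        start = time.perf_counter()
        yield
        self.latencies.append(time.perf_counter() - start)

tracker = SimpleLatencyTracker()
duration = 0.05  # seconds; the real loop reads self.config.duration

tracker.reset()
while tracker.get_elapsed_time() < duration:
    with tracker.track():
        sum(i * i for i in range(10_000))  # stand-in for backend.forward(...)

mean = sum(tracker.latencies) / len(tracker.latencies)
print(f"{len(tracker.latencies)} runs, mean latency {mean * 1e6:.1f} us")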
17 changes: 17 additions & 0 deletions optimum_benchmark/benchmarks/inference/inputs_utils.py
@@ -0,0 +1,17 @@
+def extract_text_generation_inputs(inputs):
+    if "pixel_values" in inputs:
+        # image input
+        text_generation_inputs = {"inputs": inputs["pixel_values"]}
+    elif "input_values" in inputs:
+        # speech input
+        text_generation_inputs = {"inputs": inputs["input_values"]}
+    elif "input_features" in inputs:
+        # waveform input
+        text_generation_inputs = {"inputs": inputs["input_features"]}
+    elif "input_ids" in inputs:
+        # text input
+        text_generation_inputs = {"inputs": inputs["input_ids"]}
+    else:
+        raise ValueError("Could not find any valid text generation inputs.")
+
+    return text_generation_inputs
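
A quick usage sketch for the new helper (assumes optimum-benchmark with this commit is installed; plain nested lists stand in for real tensors, since the function only inspects dictionary keys):

from optimum_benchmark.benchmarks.inference.inputs_utils import extract_text_generation_inputs

# Text task: the prepared forward inputs carry input_ids plus extras
# (e.g. attention_mask) that generate() does not need.
forward_inputs = {"input_ids": [[101, 7592, 102]], "attention_mask": [[1, 1, 1]]}
print(extract_text_generation_inputs(forward_inputs))
# {'inputs': [[101, 7592, 102]]}

# Image-to-text task: pixel_values takes precedence because it is checked first.
forward_inputs = {"pixel_values": [[[0.1, 0.2]]], "input_ids": [[101]]}
print(extract_text_generation_inputs(forward_inputs))
# {'inputs': [[[0.1, 0.2]]]}

# A dict with no recognized modality key raises
# ValueError("Could not find any valid text generation inputs.")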
