Skip to content

Commit

Permalink
fix
Browse files: browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Nov 27, 2024
1 parent b157b89 commit db3c8f3
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions optimum_benchmark/scenarios/inference/scenario.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -66,15 +66,17 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:
self.logger.info("\t+ Updating Text Generation kwargs with default values")
self.config.generate_kwargs = {**TEXT_GENERATION_DEFAULT_KWARGS, **self.config.generate_kwargs}
self.logger.info("\t+ Initializing Text Generation report")
self.report = BenchmarkReport.from_list(targets=["load", "prefill", "decode", "per_token"])
self.report = BenchmarkReport.from_list(targets=["load_model", "prefill", "decode", "per_token"])
elif self.backend.config.task in IMAGE_DIFFUSION_TASKS:
self.logger.info("\t+ Updating Image Diffusion kwargs with default values")
self.config.call_kwargs = {**IMAGE_DIFFUSION_DEFAULT_KWARGS, **self.config.call_kwargs}
self.logger.info("\t+ Initializing Image Diffusion report")
self.report = BenchmarkReport.from_list(targets=["load", "call"])
self.report = BenchmarkReport.from_list(targets=["load_model", "call"])
else:
self.logger.info("\t+ Initializing Inference report")
self.report = BenchmarkReport.from_list(targets=["load", "forward"])
self.report = BenchmarkReport.from_list(targets=["load_model", "forward"])

self.run_model_loading_tracking(backend)

self.logger.info("\t+ Creating input generator")
self.input_generator = InputGenerator(
Expand All @@ -83,15 +85,11 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:
input_shapes=self.config.input_shapes,
model_type=backend.config.model_type,
)

self.logger.info("\t+ Generating inputs")
self.inputs = self.input_generator()

self.logger.info("\t+ Preparing inputs for Inference")
self.logger.info("\t+ Preparing inputs for backend")
self.inputs = backend.prepare_inputs(inputs=self.inputs)

self.run_model_loading_tracking(backend)

if self.config.latency or self.config.energy:
# latency and energy are metrics that require some warmup
if self.config.warmup_runs > 0:
Expand Down Expand Up @@ -159,8 +157,14 @@ def run_model_loading_tracking(self, backend: Backend[BackendConfigT]):
)
if self.config.latency:
latency_tracker = LatencyTracker(backend=backend.config.name, device=backend.config.device)
if self.config.energy:
energy_tracker = EnergyTracker(
backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids
)

with ExitStack() as context_stack:
if self.config.energy:
context_stack.enter_context(energy_tracker.track())
if self.config.memory:
context_stack.enter_context(memory_tracker.track())
if self.config.latency:
Expand All @@ -169,9 +173,11 @@ def run_model_loading_tracking(self, backend: Backend[BackendConfigT]):
backend.load()

if self.config.latency:
self.report.load.latency = latency_tracker.get_latency()
self.report.load_model.latency = latency_tracker.get_latency()
if self.config.memory:
self.report.load.memory = memory_tracker.get_max_memory()
self.report.load_model.memory = memory_tracker.get_max_memory()
if self.config.energy:
self.report.load_model.energy = energy_tracker.get_energy()

## Memory tracking
def run_text_generation_memory_tracking(self, backend: Backend[BackendConfigT]):
Expand Down

0 comments on commit db3c8f3

Please sign in to comment.