Skip to content

Commit

Permalink
Add the ability to profile test runs using cProfile. (google#921)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 293921287
  • Loading branch information
arsharma1 authored Mar 2, 2020
1 parent 41df40f commit 195d005
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 54 deletions.
38 changes: 23 additions & 15 deletions openhtf/core/phase_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,10 @@ class PhaseExecutorThread(threads.KillableThread):
"""
daemon = True

def __init__(self, phase_desc, test_state):
def __init__(self, phase_desc, test_state, run_with_profiling):
super(PhaseExecutorThread, self).__init__(
name='<PhaseExecutorThread: (phase_desc.name)>')
name='<PhaseExecutorThread: (phase_desc.name)>',
run_with_profiling=run_with_profiling)
self._phase_desc = phase_desc
self._test_state = test_state
self._phase_execution_outcome = None
Expand Down Expand Up @@ -208,39 +209,44 @@ def __init__(self, test_state):
self._current_phase_thread = None
self._stopping = threading.Event()

def execute_phase(self, phase):
def execute_phase(self, phase, run_with_profiling=False):
"""Executes a phase or skips it, yielding PhaseExecutionOutcome instances.
Args:
phase: Phase to execute.
run_with_profiling: Whether to run with cProfile stat collection for the
phase code run inside a thread.
Returns:
The final PhaseExecutionOutcome that wraps the phase return value
(or exception) of the final phase run. All intermediary results, if any,
are REPEAT and handled internally. Returning REPEAT here means the phase
hit its limit for repetitions.
A two-tuple; the first item is the final PhaseExecutionOutcome that wraps
the phase return value (or exception) of the final phase run. All
intermediary results, if any, are REPEAT and handled internally. Returning
REPEAT here means the phase hit its limit for repetitions.
The second tuple item is the profiler Stats object if profiling was
requested and successfully ran for this phase execution.
"""
repeat_count = 1
repeat_limit = phase.options.repeat_limit or sys.maxsize
while not self._stopping.is_set():
is_last_repeat = repeat_count >= repeat_limit
phase_execution_outcome = self._execute_phase_once(phase, is_last_repeat)
phase_execution_outcome, profile_stats = self._execute_phase_once(
phase, is_last_repeat, run_with_profiling)

if phase_execution_outcome.is_repeat and not is_last_repeat:
repeat_count += 1
continue

return phase_execution_outcome
return phase_execution_outcome, profile_stats
# We've been cancelled, so just 'timeout' the phase.
return PhaseExecutionOutcome(None)
return PhaseExecutionOutcome(None), None

def _execute_phase_once(self, phase_desc, is_last_repeat):
def _execute_phase_once(self, phase_desc, is_last_repeat, run_with_profiling):
"""Executes the given phase, returning a PhaseExecutionOutcome."""
# Check this before we create a PhaseState and PhaseRecord.
if phase_desc.options.run_if and not phase_desc.options.run_if():
_LOG.debug('Phase %s skipped due to run_if returning falsey.',
phase_desc.name)
return PhaseExecutionOutcome(openhtf.PhaseResult.SKIP)
return PhaseExecutionOutcome(openhtf.PhaseResult.SKIP), None

override_result = None
with self.test_state.running_phase_context(phase_desc) as phase_state:
Expand All @@ -256,8 +262,9 @@ def _execute_phase_once(self, phase_desc, is_last_repeat):
# Killed result.
result = PhaseExecutionOutcome(threads.ThreadTerminationError())
phase_state.result = result
return result
phase_thread = PhaseExecutorThread(phase_desc, self.test_state)
return result, None
phase_thread = PhaseExecutorThread(phase_desc, self.test_state,
run_with_profiling)
phase_thread.start()
self._current_phase_thread = phase_thread

Expand All @@ -273,7 +280,8 @@ def _execute_phase_once(self, phase_desc, is_last_repeat):
result = override_result or phase_state.result
_LOG.debug('Phase %s finished with result %s', phase_desc.name,
result.phase_result)
return result
return (result,
phase_thread.get_profile_stats() if run_with_profiling else None)

def reset_stop(self):
self._stopping.clear()
Expand Down
17 changes: 13 additions & 4 deletions openhtf/core/test_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,16 @@ def _get_running_test_descriptor(self):
teardown=[teardown_phase]),
self._test_desc.code_info, self._test_desc.metadata)

def execute(self, test_start=None):
def execute(self, test_start=None, profile_filename=None):
"""Starts the framework and executes the given test.
Args:
test_start: Either a trigger phase for starting the test, or a function
that returns a DUT ID. If neither is provided, defaults to not
setting the DUT ID.
that returns a DUT ID. If neither is provided, defaults to not
setting the DUT ID.
profile_filename: Name of file to put profiling stats into. This also
enables profiling data collection.
Returns:
Boolean indicating whether the test failed (False) or passed (True).
Expand Down Expand Up @@ -309,7 +312,11 @@ def trigger_phase(test):

test_desc = self._get_running_test_descriptor()
self._executor = test_executor.TestExecutor(
test_desc, self.make_uid(), trigger, self._test_options)
test_desc,
self.make_uid(),
trigger,
self._test_options,
run_with_profiling=profile_filename is not None)

_LOG.info('Executing test: %s', self.descriptor.code_info.name)
self.TEST_INSTANCES[self.uid] = self
Expand All @@ -328,6 +335,8 @@ def trigger_phase(test):

_LOG.debug('Test completed for %s, outputting now.',
final_state.test_record.metadata['test_name'])
test_executor.CombineProfileStats(self._executor.phase_profile_stats,
profile_filename)
for output_cb in self._test_options.output_callbacks:
try:
output_cb(final_state.test_record)
Expand Down
44 changes: 40 additions & 4 deletions openhtf/core/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
"""TestExecutor executes tests."""

import logging
import pstats
import sys
import tempfile
import threading

from openhtf.core import phase_descriptor
Expand Down Expand Up @@ -46,13 +48,31 @@ class TestStopError(Exception):
"""Test is being stopped."""


def CombineProfileStats(profile_stats_iter, output_filename):
"""Given an iterable of pstats.Stats, combine them into a single Stats."""
profile_stats_filenames = []
for profile_stats in profile_stats_iter:
with tempfile.NamedTemporaryFile(delete=False) as f:
profile_stats_filename = f.name
profile_stats.dump_stats(profile_stats_filename)
profile_stats_filenames.append(profile_stats_filename)
if profile_stats_filenames:
pstats.Stats(*profile_stats_filenames).dump_stats(output_filename)


# pylint: disable=too-many-instance-attributes
class TestExecutor(threads.KillableThread):
"""Encompasses the execution of a single test."""
daemon = True

def __init__(self, test_descriptor, execution_uid, test_start, test_options):
super(TestExecutor, self).__init__(name='TestExecutorThread')
def __init__(self,
test_descriptor,
execution_uid,
test_start,
test_options,
run_with_profiling):
super(TestExecutor, self).__init__(
name='TestExecutorThread', run_with_profiling=run_with_profiling)
self.test_state = None

self._test_descriptor = test_descriptor
Expand All @@ -65,6 +85,12 @@ def __init__(self, test_descriptor, execution_uid, test_start, test_options):
self._abort = threading.Event()
self._full_abort = threading.Event()
self._teardown_phases_lock = threading.Lock()
self._phase_profile_stats = [] # Populated if profiling is enabled.

@property
def phase_profile_stats(self):
"""Returns iterable of profiling Stats objects, per phase."""
return self._phase_profile_stats

def close(self):
"""Close and remove any global registrations.
Expand Down Expand Up @@ -194,7 +220,12 @@ def _execute_test_start(self):
phase_plug.cls for phase_plug in self._test_start.plugs]):
return True

outcome = self._phase_exec.execute_phase(self._test_start)
outcome, profile_stats = self._phase_exec.execute_phase(
self._test_start, self._run_with_profiling)

if profile_stats is not None:
self._phase_profile_stats.append(profile_stats)

if outcome.is_terminal:
self._last_outcome = outcome
return True
Expand Down Expand Up @@ -239,7 +270,12 @@ def _handle_phase(self, phase):
return self._execute_phase_group(phase)

self.test_state.state_logger.debug('Handling phase %s', phase.name)
outcome = self._phase_exec.execute_phase(phase)
outcome, profile_stats = self._phase_exec.execute_phase(
phase, self._run_with_profiling)

if profile_stats is not None:
self._phase_profile_stats.append(profile_stats)

if (self.test_state.test_options.stop_on_first_failure or
conf.stop_on_first_failure):
# Stop Test on first measurement failure
Expand Down
1 change: 0 additions & 1 deletion openhtf/output/callbacks/mfg_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,4 +262,3 @@ def _converter(test_record_obj):
def __call__(self, test_record_obj): # pylint: disable=invalid-name
upload_callback = self.upload()
upload_callback(test_record_obj)

3 changes: 2 additions & 1 deletion openhtf/util/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ def _handle_phase(self, phase_desc):
phase_executor.PhaseExecutorThread, '_log_exception',
side_effect=logging.exception):
# Use _execute_phase_once because we want to expose all possible outcomes.
executor._execute_phase_once(phase_desc, is_last_repeat=False)
executor._execute_phase_once(
phase_desc, is_last_repeat=False, run_with_profiling=False)
return test_state_.test_record.phases[-1]

def _handle_test(self, test):
Expand Down
28 changes: 28 additions & 0 deletions openhtf/util/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
"""Thread library defining a few helpers."""

import contextlib
import cProfile
import ctypes
import functools
import logging
import pstats
import sys
import threading

Expand All @@ -36,6 +38,10 @@ class ThreadTerminationError(SystemExit):
"""Sibling of SystemExit, but specific to thread termination."""


class InvalidUsageError(Exception):
"""Raised when an API is used in an invalid or unsupported manner."""


def safe_lock_release_context(rlock):
if six.PY2:
return _safe_lock_release_py2(rlock)
Expand Down Expand Up @@ -141,15 +147,28 @@ class is meant to be subclassed. If you were to invoke this with
"""

def __init__(self, *args, **kwargs):
"""Initializer for KillableThread.
Args:
run_with_profiling: Whether to run this thread with profiling data
collection. Must be passed by keyword.
"""
self._run_with_profiling = kwargs.pop('run_with_profiling', None)
super(KillableThread, self).__init__(*args, **kwargs)
self._running_lock = threading.Lock()
self._killed = threading.Event()
if self._run_with_profiling:
self._profiler = cProfile.Profile()
else:
self._profiler = None

def run(self):
try:
with self._running_lock:
if self._killed.is_set():
raise ThreadTerminationError()
if self._profiler is not None:
self._profiler.enable()
self._thread_proc()
except Exception: # pylint: disable=broad-except
if not self._thread_exception(*sys.exc_info()):
Expand All @@ -158,6 +177,15 @@ def run(self):
finally:
self._thread_finished()
_LOG.debug('Thread finished: %s', self.name)
if self._profiler is not None:
self._profiler.disable()

def get_profile_stats(self):
"""Returns profile_stats from profiler. Raises if profiling not enabled."""
if self._profiler is not None:
return pstats.Stats(self._profiler)
raise InvalidUsageError(
'Profiling not enabled via __init__, or thread has not run yet.')

def _is_thread_proc_running(self):
# Acquire the lock without blocking, though this object is fully implemented
Expand Down
Loading

0 comments on commit 195d005

Please sign in to comment.