diff --git a/openhtf/core/phase_executor.py b/openhtf/core/phase_executor.py index ddb7de0de..a35866969 100644 --- a/openhtf/core/phase_executor.py +++ b/openhtf/core/phase_executor.py @@ -34,6 +34,7 @@ import collections import logging +import traceback import openhtf from openhtf.util import argv @@ -50,6 +51,16 @@ _LOG = logging.getLogger(__name__) +class ExceptionInfo(collections.namedtuple( + 'ExceptionInfo', ['exc_type', 'exc_val', 'exc_tb'])): + def _asdict(self): + return { + 'exc_type': str(self.exc_type), + 'exc_val': self.exc_val, + 'exc_tb': ''.join(traceback.format_exception(*self)), + } + + class InvalidPhaseResultError(Exception): """Raised when a PhaseOutcome is created with an invalid phase result.""" @@ -75,7 +86,7 @@ class PhaseOutcome(collections.namedtuple('PhaseOutcome', 'phase_result')): """ def __init__(self, phase_result): if (phase_result is not None and - not isinstance(phase_result, (openhtf.PhaseResult, Exception)) and + not isinstance(phase_result, (openhtf.PhaseResult, ExceptionInfo)) and not isinstance(phase_result, threads.ThreadTerminationError)): raise InvalidPhaseResultError('Invalid phase result', phase_result) super(PhaseOutcome, self).__init__(phase_result) @@ -89,7 +100,7 @@ def is_timeout(self): def raised_exception(self): """True if the phase in question raised an exception.""" return isinstance(self.phase_result, ( - Exception, threads.ThreadTerminationError)) + ExceptionInfo, threads.ThreadTerminationError)) @property def is_terminal(self): @@ -123,8 +134,8 @@ def _thread_proc(self): # set to the InvalidPhaseResultError in _thread_exception instead. self._phase_outcome = PhaseOutcome(phase_return) - def _thread_exception(self, exc): - self._phase_outcome = PhaseOutcome(exc) + def _thread_exception(self, *args): + self._phase_outcome = PhaseOutcome(ExceptionInfo(*args)) self._test_state.logger.exception('Phase %s raised an exception', self.name) def join_or_die(self): @@ -166,32 +177,37 @@ def execute_phases(self, phases, teardown_func): Args: phases: List of phases to execute. + teardown_func: Yields: PhaseOutcome instance that wraps the phase return value (or exception). """ - for phase in phases: - while True: - outcome = self._execute_one_phase(phase) - if outcome: - # We have to run the teardown_func *before* we yield the outcome, - # because yielding the outcome results in the state being finalized - # in the case of a terminal outcome. - if outcome.is_terminal and teardown_func: - self._execute_one_phase(teardown_func, output_record=False) - yield outcome - - # If we're done with this phase, skip to the next one. - if outcome.phase_result is openhtf.PhaseResult.CONTINUE: + try: + for phase in phases: + while True: + outcome = self._execute_one_phase(phase) + if outcome: + # We have to run the teardown_func *before* we yield the outcome, + # because yielding the outcome results in the state being finalized + # in the case of a terminal outcome. + if outcome.is_terminal and teardown_func: + self._execute_one_phase(teardown_func) + yield outcome + + # If we're done with this phase, skip to the next one. + if outcome.phase_result is openhtf.PhaseResult.CONTINUE: + break + else: + # run_if was falsey, just skip this phase. break - else: - # run_if was falsey, just skip this phase. - break - # If all phases complete with no terminal outcome, we end up here. - if teardown_func: - self._execute_one_phase(teardown_func, output_record=False) - - def _execute_one_phase(self, phase_desc, output_record=True): + if teardown_func: + self._execute_one_phase(teardown_func) + except (KeyboardInterrupt, SystemExit): + if teardown_func: + self._execute_one_phase(teardown_func) + raise + + def _execute_one_phase(self, phase_desc): """Executes the given phase, returning a PhaseOutcome.""" # Check this before we create a PhaseState and PhaseRecord. if phase_desc.options.run_if and not phase_desc.options.run_if(): @@ -199,8 +215,7 @@ def _execute_one_phase(self, phase_desc, output_record=True): phase_desc.name) return - with self.test_state.running_phase_context( - phase_desc, output_record) as phase_state: + with self.test_state.running_phase_context(phase_desc) as phase_state: _LOG.info('Executing phase %s', phase_desc.name) phase_thread = PhaseExecutorThread(phase_desc, self.test_state) phase_thread.start() diff --git a/openhtf/core/test_state.py b/openhtf/core/test_state.py index 38345a72b..50b952116 100644 --- a/openhtf/core/test_state.py +++ b/openhtf/core/test_state.py @@ -128,7 +128,7 @@ def test_api(self): self.notify_update)) @contextlib.contextmanager - def running_phase_context(self, phase_desc, output_record=True): + def running_phase_context(self, phase_desc): """Create a context within which a single phase is running. Yields a PhaseState object for tracking transient state during the @@ -146,11 +146,10 @@ def running_phase_context(self, phase_desc, output_record=True): self.notify_update() # New phase started. yield self.running_phase_state finally: - if output_record: - # Clear notification callbacks so we can serialize measurements. - for meas in self.running_phase_state.measurements.values(): - meas.set_notification_callback(None) - self.test_record.phases.append(self.running_phase_state.phase_record) + # Clear notification callbacks so we can serialize measurements. + for meas in self.running_phase_state.measurements.values(): + meas.set_notification_callback(None) + self.test_record.phases.append(self.running_phase_state.phase_record) self.running_phase_state = None self.notify_update() # Phase finished. diff --git a/openhtf/util/test.py b/openhtf/util/test.py index 0347edf67..81b302fd1 100644 --- a/openhtf/util/test.py +++ b/openhtf/util/test.py @@ -184,9 +184,10 @@ def _handle_phase(self, phase_desc): try: phase_state.result = phase_executor.PhaseOutcome( phase_desc(test_state_)) - except Exception as exc: # pylint:disable=broad-except + except Exception: # pylint:disable=broad-except logging.exception('Exception executing phase %s', phase_desc.name) - phase_state.result = phase_executor.PhaseOutcome(exc) + phase_state.result = phase_executor.PhaseOutcome( + phase_executor.ExceptionInfo(*sys.exc_info())) return phase_state.phase_record @@ -379,7 +380,7 @@ def assertPhaseError(self, phase_record, exc_type=None): self.assertTrue(phase_record.result.raised_exception, 'Phase did not raise an exception') if exc_type: - self.assertIsInstance(phase_record.result.phase_result, exc_type, + self.assertIsInstance(phase_record.result.phase_result.exc_val, exc_type, 'Raised exception %r is not a subclass of %r' % (phase_record.result.phase_result, exc_type)) diff --git a/openhtf/util/threads.py b/openhtf/util/threads.py index 1718f08ab..e451f16f6 100644 --- a/openhtf/util/threads.py +++ b/openhtf/util/threads.py @@ -18,6 +18,7 @@ import ctypes import functools import logging +import sys import threading _LOG = logging.getLogger(__name__) @@ -57,8 +58,8 @@ class is meant to be subclassed. If you were to invoke this with def run(self): try: self._thread_proc() - except Exception as exception: # pylint: disable=broad-except - if not self._thread_exception(exception): + except Exception: # pylint: disable=broad-except + if not self._thread_exception(*sys.exc_info()): logging.exception('Thread raised an exception: %s', self.name) raise finally: @@ -71,7 +72,7 @@ def _thread_proc(self): def _thread_finished(self): """The method called once _thread_proc has finished.""" - def _thread_exception(self, exception): + def _thread_exception(self, exc_type, exc_val, exc_tb): """The method called if _thread_proc raises an exception. To suppress the exception, return True from this method. @@ -101,9 +102,9 @@ def async_raise(self, exc_type): ctypes.pythonapi.PyThreadState_SetAsyncExc(self.ident, None) raise SystemError('PyThreadState_SetAsyncExc failed.', self.ident) - def _thread_exception(self, exception): + def _thread_exception(self, exc_type, exc_val, exc_tb): """Suppress the exception when we're kill()'d.""" - return isinstance(exception, ThreadTerminationError) + return exc_type is ThreadTerminationError class NoneByDefaultThreadLocal(threading.local): diff --git a/test/core/measurements_record.pickle b/test/core/measurements_record.pickle index 85050ddf5..9cc355c00 100644 Binary files a/test/core/measurements_record.pickle and b/test/core/measurements_record.pickle differ diff --git a/test/plugs_test.py b/test/plugs_test.py index 0e668fe5d..8e15b54ed 100644 --- a/test/plugs_test.py +++ b/test/plugs_test.py @@ -62,6 +62,7 @@ class PlugsTest(test.TestCase): def setUp(self): self.logger = object() self.plug_manager = plugs.PlugManager({AdderPlug}, self.logger) + AdderPlug.INSTANCE_COUNT = 0 def tearDown(self): self.plug_manager.tear_down_plugs()