diff --git a/splunklib/searchcommands/search_command.py b/splunklib/searchcommands/search_command.py index a093b7d96..b765778e3 100644 --- a/splunklib/searchcommands/search_command.py +++ b/splunklib/searchcommands/search_command.py @@ -630,7 +630,7 @@ def _process_protocol_v1(self, argv, ifile, ofile): except: self._report_unexpected_error() self.flush() - exit(1) + raise debug('%s.process finished under protocol_version=1', class_name) @@ -688,7 +688,7 @@ def _process_protocol_v2(self, argv, ifile, ofile): self._record_writer = RecordWriterV2(ofile) self._report_unexpected_error() self.finish() - exit(1) + raise # Write search command configuration for consumption by splunkd # noinspection PyBroadException @@ -768,7 +768,7 @@ def _process_protocol_v2(self, argv, ifile, ofile): self._record_writer.write_metadata(self._configuration) self._report_unexpected_error() self.finish() - exit(1) + raise self._record_writer.write_metadata(self._configuration) @@ -784,7 +784,7 @@ def _process_protocol_v2(self, argv, ifile, ofile): except: self._report_unexpected_error() self.finish() - exit(1) + raise debug('%s.process completed', class_name) diff --git a/tests/searchcommands/test_search_command.py b/tests/searchcommands/test_search_command.py index 386b5bb65..ccf246c19 100755 --- a/tests/searchcommands/test_search_command.py +++ b/tests/searchcommands/test_search_command.py @@ -35,6 +35,9 @@ import pytest +class CustomTestCommandException(Exception): + pass + @Configuration() class TestCommand(SearchCommand): @@ -44,10 +47,7 @@ class TestCommand(SearchCommand): def echo(self, records): for record in records: if record.get('action') == 'raise_exception': - if six.PY2: - raise StandardError(self) - else: - raise Exception(self) + raise CustomTestCommandException() yield record def _execute(self, ifile, process): @@ -83,6 +83,8 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings): # endregion +class CustomRuntimeError(RuntimeError): + pass @Configuration() class TestStreamingCommand(StreamingCommand): @@ -91,7 +93,7 @@ def stream(self, records): for record in records: action = record['action'] if action == 'raise_error': - raise RuntimeError('Testing') + raise CustomRuntimeError('Testing') value = self.search_results_info if action == 'get_search_results_info' else None yield {'_serial': serial_number, 'data': value} serial_number += 1 @@ -126,7 +128,7 @@ def test_process_scpv1(self): command = TestCommand() result = BytesIO() - self.assertRaises(SystemExit, command.process, argv, ofile=result) + self.assertRaises(RuntimeError, command.process, argv, ofile=result) self.assertRegexpMatches(result.getvalue().decode('UTF-8'), expected) # TestCommand.process should return configuration settings on Getinfo probe @@ -292,15 +294,14 @@ def test_process_scpv1(self): try: # noinspection PyTypeChecker command.process(argv, ifile, ofile=result) - except SystemExit as error: - self.assertNotEqual(error.code, 0) + except CustomRuntimeError: self.assertRegexpMatches( result.getvalue().decode('UTF-8'), r'^error_message=RuntimeError at ".+", line \d+ : Testing\r\n\r\n$') except BaseException as error: - self.fail('Expected SystemExit, but caught {}: {}'.format(type(error).__name__, error)) + self.fail('Expected CustomRuntimeError, but caught {}: {}'.format(type(error).__name__, error)) else: - self.fail('Expected SystemExit, but no exception was raised') + self.fail('Expected CustomRuntimeError, but no exception was raised') # Command.process should provide access to search results info info_path = os.path.join( @@ -676,12 +677,12 @@ def test_process_scpv2(self): try: command.process(argv, ifile, ofile=result) - except SystemExit as error: - self.assertNotEqual(0, error.code) + except CustomTestCommandException: + pass # Expected exception was preserved. except BaseException as error: self.fail('{0}: {1}: {2}\n'.format(type(error).__name__, error, result.getvalue().decode('utf-8'))) else: - self.fail('Expected SystemExit, not a return from TestCommand.process: {}\n'.format(result.getvalue().decode('utf-8'))) + self.fail('Expected CustomTestCommandException, not a return from TestCommand.process: {}\n'.format(result.getvalue().decode('utf-8'))) self.assertEqual(command.logging_configuration, logging_configuration) self.assertEqual(command.logging_level, logging_level) @@ -692,17 +693,13 @@ def test_process_scpv2(self): finished = r'\"finished\":true' - if six.PY2: - inspector = \ - r'\"inspector\":\{\"messages\":\[\[\"ERROR\",\"StandardError at \\\".+\\\", line \d+ : test ' \ - r'logging_configuration=\\\".+\\\" logging_level=\\\"WARNING\\\" record=\\\"f\\\" ' \ - r'required_option_1=\\\"value_1\\\" required_option_2=\\\"value_2\\\" show_configuration=\\\"f\\\"\"\]\]\}' - else: - inspector = \ - r'\"inspector\":\{\"messages\":\[\[\"ERROR\",\"Exception at \\\".+\\\", line \d+ : test ' \ - r'logging_configuration=\\\".+\\\" logging_level=\\\"WARNING\\\" record=\\\"f\\\" ' \ - r'required_option_1=\\\"value_1\\\" required_option_2=\\\"value_2\\\" show_configuration=\\\"f\\\"\"\]\]\}' + inspector = \ + r'\"inspector\":\{\"messages\":\[\[\"ERROR\",\"CustomTestCommandException at \\\".+\\\", line \d+ : test ' \ + r'logging_configuration=\\\".+\\\" logging_level=\\\"WARNING\\\" record=\\\"f\\\" ' \ + r'required_option_1=\\\"value_1\\\" required_option_2=\\\"value_2\\\" show_configuration=\\\"f\\\"\"\]\]\}' + res = result.getvalue().decode('utf-8') + print('result = '+res) self.assertRegexpMatches( result.getvalue().decode('utf-8'), r'^chunked 1.0,2,0\n'