From 525dbb5e65843e3a4acf684685706770126a250f Mon Sep 17 00:00:00 2001 From: Joost van Zwieten Date: Sat, 11 May 2024 13:43:26 +0200 Subject: [PATCH] WIP: unittest --- .github/workflows/test.yaml | 3 +- devtools/gha/unittest.py | 229 ++++++++++++++++++++++++++++++++++++ 2 files changed, 231 insertions(+), 1 deletion(-) create mode 100644 devtools/gha/unittest.py diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 229f28821..3b4edf7b5 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -108,7 +108,8 @@ jobs: python -um pip install --upgrade --upgrade-strategy eager mkl python -um devtools.gha.configure_mkl - name: Test - run: python -um coverage run -m unittest discover -b -q -t . -s tests + #run: python -um coverage run -m unittest discover -b -q -t . -s tests + run: python -um devtools.gha.unittest - name: Post-process coverage run: python -um devtools.gha.coverage_report_xml - name: Upload coverage diff --git a/devtools/gha/unittest.py b/devtools/gha/unittest.py new file mode 100644 index 000000000..2614a2710 --- /dev/null +++ b/devtools/gha/unittest.py @@ -0,0 +1,229 @@ +import unittest, pathlib, multiprocessing, sys, ctypes, io, traceback, contextlib, tempfile, os + +def _unpack_test_suites(suite): + if type(suite) == unittest.TestSuite: + for test in suite: + yield from _unpack_test_suites(test) + else: + yield suite + +# TODO: keep tests in modules or classes with fixtures together + + +class TestSuite(unittest.TestSuite): + + def __init__(self, tests, shared_iter): + super().__init__() + del self._tests + self.__tests = tests + self.__shared_iter = shared_iter + + def __iter__(self): + for index in self.__shared_iter: + yield self.__tests[index] + + def countTestCases(self): + raise NotImplementedError + + def addTest(self, test): + raise NotImplementedError + + def addTests(self, tests): + if tests: + raise NotImplementedError + + def _removeTestAtIndex(self, index): + pass + + +class SharedRange: + + def __init__(self, length): + self._length = length + self._index = multiprocessing.Value(ctypes.c_int) + with self._index.get_lock(): + self._index.value = 0 + + def __iter__(self): + return self + + def __next__(self): + with self._index.get_lock(): + index = self._index.value + if index >= self._length: + raise StopIteration + self._index.value += 1 + return index + + +class TestResult: + + def __init__(self, *, ntests, stream, failfast=False, withtraceback=True, errors=None, skips=None): + self._stream = stream + self._lock = multiprocessing.RLock() + for attr in 'done', 'failed', 'errored', 'skipped': + cnt = multiprocessing.Value(ctypes.c_int, lock=False) + cnt.value = 0 + setattr(self, '_'+attr, cnt) + self._shouldStop = multiprocessing.Value(ctypes.c_bool, lock=False) + self._ntests = ntests + self._fmt = '\r\033[K{{:{0}}}/{1} {{:{0}}} F {{:{0}}} E {{:{0}}} S'.format(len(str(self._ntests)), self._ntests) + self.failfast = failfast + self.withtraceback = withtraceback + self.buffer = True + self._errors = {} if errors is None else errors + self._skips = {} if skips is None else skips + self._original_stdout = sys.stdout + self._original_stderr = sys.stderr + + @property + def shouldStop(self): + return self._shouldStop.value + + @shouldStop.setter + def shouldStop(self, value): + with self._lock: + self._shouldStop.value = bool(value) + + def stop(self): + self.shouldStop = True + + @property + def testsRun(self): + with self._lock: + return self._done.value + + @property + def wasSuccessful(self): + with self._lock: + return not self._failed.value and not self._errored.value + + def startTest(self, test): + self._stderr_buffer = io.StringIO() + self._stdout_buffer = io.StringIO() + sys.stdout = self._stdout_buffer + sys.stderr = self._stderr_buffer + + def stopTest(self, test): + sys.stdout = self._original_stdout + sys.stderr = self._original_stderr + with self._lock: + self._done.value += 1 + self._report_locked(test) + + def startTestRun(self): + raise NotImplementedError + + def stopTestRun(self): + raise NotImplementedError + + def addSuccess(self, test): + pass + + def addError(self, test, err): + tb = ''.join(traceback.format_exception(*err)) + with self._lock: + if self.withtraceback: + sameas = self._errors.get(tb) + if not sameas: + self._errors[tb] = str(test) + self._stream.write('\r\033[KERRORED: {}\n'.format(test)) + self._stream.write(tb) + self._stream.write('\n') + #else: + # self._stream.write('\r\033[KERRORED: {} (same traceback as {})\n'.format(test, sameas)) + else: + self._stream.write('\r\033[KERRORED: {}\n'.format(test)) + self._stream.flush() + self._errored.value += 1 + + def addFailure(self, test, err): + with self._lock: + errmsg = str(err[1]) + self._stream.write('\r\033[KFAILED: {}{}{}\n'.format(test, '\n' if '\n' in errmsg else ': ', errmsg)) + self._stream.flush() + self._failed.value += 1 + + def addSkip(self, test, reason): + with self._lock: + self._skipped.value += 1 + self._skips[reason] = self._skips.get(reason, 0) + 1 + + def addExpectedFailure(self, test, err): + with self._lock: + self._success.value += 1 + + def addUnexpectedSuccess(self, test): + with self._lock: + self._failed.value += 1 + + def addSubTest(self, test, subtest, err): + if err is not None: + if issubclass(err[0], test.failureException): + self.addFailure(subtest, err) + else: + self.addError(subtest, err) + with self._lock: + self._report_locked(test) + + def _report_locked(self, test): + self._stream.write(self._fmt.format(self._done.value, self._failed.value, self._errored.value, self._skipped.value)) + self._stream.flush() + + +if __name__ == '__main__': + root = pathlib.Path().resolve() + sys.path.insert(0, str(root)) + withcoverage = True + withtraceback = True + failfast = '-f' in sys.argv + + with contextlib.ExitStack() as stack, multiprocessing.Manager() as manager: + + if withcoverage: + import coverage + tmpdir = stack.enter_context(tempfile.TemporaryDirectory()) + cov = coverage.Coverage() + cov.start() + + loader = unittest.TestLoader() + suite = loader.discover(str(root)) + #suite = loader.discover('tests.test_function', top_level_dir=str(root)) + tests = tuple(_unpack_test_suites(suite)) + + r = SharedRange(len(tests)) + s = TestSuite(tests, r) + + skips = manager.dict() + result = TestResult(stream=sys.stderr, ntests=suite.countTestCases(), failfast=failfast, withtraceback=withtraceback, errors=manager.dict(), skips=skips) + + def run(): + if withcoverage: + wcov = coverage.Coverage(data_file=os.path.join(tmpdir, '.coverage.{}'.format(os.getpid()))) + wcov.start() + s.run(result) + if withcoverage: + wcov.stop() + wcov.save() + + workers = [] + for i in range(multiprocessing.cpu_count()): + w = multiprocessing.Process(target=run) + w.start() + workers.append(w) + for w in workers: + w.join() + + sys.stderr.write('\n') + sys.stderr.flush() + + if skips: + print('SKIP REASONS:', file=sys.stderr) + for reason, cnt in sorted(skips.items(), key=lambda item: (-item[1], item[0])): + print(f'{cnt:4d} {reason}', file=sys.stderr) + + if withcoverage: + cov.stop() + cov.combine([os.path.join(tmpdir, f) for f in os.listdir(tmpdir)]) + cov.save() + cov.report()