Skip to content

Commit

Permalink
WIP: unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
joostvanzwieten committed May 11, 2024
1 parent 256c05b commit 525dbb5
Show file tree
Hide file tree
Showing 2 changed files with 231 additions and 1 deletion.
3 changes: 2 additions & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ jobs:
python -um pip install --upgrade --upgrade-strategy eager mkl
python -um devtools.gha.configure_mkl
- name: Test
run: python -um coverage run -m unittest discover -b -q -t . -s tests
#run: python -um coverage run -m unittest discover -b -q -t . -s tests
run: python -um devtools.gha.unittest
- name: Post-process coverage
run: python -um devtools.gha.coverage_report_xml
- name: Upload coverage
Expand Down
229 changes: 229 additions & 0 deletions devtools/gha/unittest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
import unittest, pathlib, multiprocessing, sys, ctypes, io, traceback, contextlib, tempfile, os

def _unpack_test_suites(suite):
if type(suite) == unittest.TestSuite:
for test in suite:
yield from _unpack_test_suites(test)
else:
yield suite

# TODO: keep tests in modules or classes with fixtures together


class TestSuite(unittest.TestSuite):

def __init__(self, tests, shared_iter):
super().__init__()
del self._tests
self.__tests = tests
self.__shared_iter = shared_iter

def __iter__(self):
for index in self.__shared_iter:
yield self.__tests[index]

def countTestCases(self):
raise NotImplementedError

def addTest(self, test):
raise NotImplementedError

def addTests(self, tests):
if tests:
raise NotImplementedError

def _removeTestAtIndex(self, index):
pass


class SharedRange:

def __init__(self, length):
self._length = length
self._index = multiprocessing.Value(ctypes.c_int)
with self._index.get_lock():
self._index.value = 0

def __iter__(self):
return self

def __next__(self):
with self._index.get_lock():
index = self._index.value
if index >= self._length:
raise StopIteration
self._index.value += 1
return index


class TestResult:

def __init__(self, *, ntests, stream, failfast=False, withtraceback=True, errors=None, skips=None):
self._stream = stream
self._lock = multiprocessing.RLock()
for attr in 'done', 'failed', 'errored', 'skipped':
cnt = multiprocessing.Value(ctypes.c_int, lock=False)
cnt.value = 0
setattr(self, '_'+attr, cnt)
self._shouldStop = multiprocessing.Value(ctypes.c_bool, lock=False)
self._ntests = ntests
self._fmt = '\r\033[K{{:{0}}}/{1} {{:{0}}} F {{:{0}}} E {{:{0}}} S'.format(len(str(self._ntests)), self._ntests)
self.failfast = failfast
self.withtraceback = withtraceback
self.buffer = True
self._errors = {} if errors is None else errors
self._skips = {} if skips is None else skips
self._original_stdout = sys.stdout
self._original_stderr = sys.stderr

@property
def shouldStop(self):
return self._shouldStop.value

@shouldStop.setter
def shouldStop(self, value):
with self._lock:
self._shouldStop.value = bool(value)

def stop(self):
self.shouldStop = True

@property
def testsRun(self):
with self._lock:
return self._done.value

@property
def wasSuccessful(self):
with self._lock:
return not self._failed.value and not self._errored.value

def startTest(self, test):
self._stderr_buffer = io.StringIO()
self._stdout_buffer = io.StringIO()
sys.stdout = self._stdout_buffer
sys.stderr = self._stderr_buffer

def stopTest(self, test):
sys.stdout = self._original_stdout
sys.stderr = self._original_stderr
with self._lock:
self._done.value += 1
self._report_locked(test)

def startTestRun(self):
raise NotImplementedError

def stopTestRun(self):
raise NotImplementedError

def addSuccess(self, test):
pass

def addError(self, test, err):
tb = ''.join(traceback.format_exception(*err))
with self._lock:
if self.withtraceback:
sameas = self._errors.get(tb)
if not sameas:
self._errors[tb] = str(test)
self._stream.write('\r\033[KERRORED: {}\n'.format(test))
self._stream.write(tb)
self._stream.write('\n')
#else:
# self._stream.write('\r\033[KERRORED: {} (same traceback as {})\n'.format(test, sameas))
else:
self._stream.write('\r\033[KERRORED: {}\n'.format(test))
self._stream.flush()
self._errored.value += 1

def addFailure(self, test, err):
with self._lock:
errmsg = str(err[1])
self._stream.write('\r\033[KFAILED: {}{}{}\n'.format(test, '\n' if '\n' in errmsg else ': ', errmsg))
self._stream.flush()
self._failed.value += 1

def addSkip(self, test, reason):
with self._lock:
self._skipped.value += 1
self._skips[reason] = self._skips.get(reason, 0) + 1

def addExpectedFailure(self, test, err):
with self._lock:
self._success.value += 1

def addUnexpectedSuccess(self, test):
with self._lock:
self._failed.value += 1

def addSubTest(self, test, subtest, err):
if err is not None:
if issubclass(err[0], test.failureException):
self.addFailure(subtest, err)
else:
self.addError(subtest, err)
with self._lock:
self._report_locked(test)

def _report_locked(self, test):
self._stream.write(self._fmt.format(self._done.value, self._failed.value, self._errored.value, self._skipped.value))
self._stream.flush()


if __name__ == '__main__':
root = pathlib.Path().resolve()
sys.path.insert(0, str(root))
withcoverage = True
withtraceback = True
failfast = '-f' in sys.argv

with contextlib.ExitStack() as stack, multiprocessing.Manager() as manager:

if withcoverage:
import coverage
tmpdir = stack.enter_context(tempfile.TemporaryDirectory())
cov = coverage.Coverage()
cov.start()

loader = unittest.TestLoader()
suite = loader.discover(str(root))
#suite = loader.discover('tests.test_function', top_level_dir=str(root))
tests = tuple(_unpack_test_suites(suite))

r = SharedRange(len(tests))
s = TestSuite(tests, r)

skips = manager.dict()
result = TestResult(stream=sys.stderr, ntests=suite.countTestCases(), failfast=failfast, withtraceback=withtraceback, errors=manager.dict(), skips=skips)

def run():
if withcoverage:
wcov = coverage.Coverage(data_file=os.path.join(tmpdir, '.coverage.{}'.format(os.getpid())))
wcov.start()
s.run(result)
if withcoverage:
wcov.stop()
wcov.save()

workers = []
for i in range(multiprocessing.cpu_count()):
w = multiprocessing.Process(target=run)
w.start()
workers.append(w)
for w in workers:
w.join()

sys.stderr.write('\n')
sys.stderr.flush()

if skips:
print('SKIP REASONS:', file=sys.stderr)
for reason, cnt in sorted(skips.items(), key=lambda item: (-item[1], item[0])):
print(f'{cnt:4d} {reason}', file=sys.stderr)

if withcoverage:
cov.stop()
cov.combine([os.path.join(tmpdir, f) for f in os.listdir(tmpdir)])
cov.save()
cov.report()

0 comments on commit 525dbb5

Please sign in to comment.