-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtest-docs
executable file
·308 lines (259 loc) · 9.9 KB
/
test-docs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
#!/usr/bin/python3
import argparse
import multiprocessing
import contextlib
from collections import defaultdict
import os
import io
import sys
import warnings
import shutil
import logging
import traceback
import tempfile
import ast
from termcolor import cprint
import importlib
import json
import docutils
import docutils.core
import docutils.nodes
# Module logger: getLogger with no name is the root logger
log = logging.getLogger(None)


class Fail(Exception):
    """
    Error meant to be shown to the user as a plain message.

    When raised out of main(), the script prints the message to stderr
    (without a traceback) and exits with status 1.
    """
class RunEnv:
    """
    Test environment for python tests.

    Each instance runs one collected snippet (``entry``, a dict with at least
    ``code`` and ``src``), annotating that dict in place with ``result``,
    ``warnings``, ``stdout``, ``stderr`` and, on failure, ``exception``.
    """
    # Sample BUFR file copied into the sandbox and imported into the test DB
    TEST_BUFR = "extra/bufr/synop-sunshine.bufr"

    def __init__(self, entry):
        self.entry = entry
        # Snippets run with a copy of this module's globals...
        self.globals = dict(globals())
        # ...and with locals containing ``self`` (and ``entry``), so that
        # snippets can call ``self.assertEqual``
        self.locals = dict(locals())
        # Plain names the snippet binds itself, and names it uses as the base
        # of an attribute access without binding them (see check_reqs)
        self.assigned = set()
        self.wanted = set()
        self.stdout = io.StringIO()
        self.stderr = io.StringIO()

        # Import modules expected by the snippets
        for modname in ("dballe", "wreport", "os", "datetime"):
            self.globals[modname] = importlib.__import__(modname)

    def check_reqs(self):
        """
        Check if the snippet expects some well-known variables to be defined.

        Records in ``self.assigned`` the plain names bound by assignments and
        by ``with ... as name``, and in ``self.wanted`` the names used as the
        base of ``name.attr`` that had not been seen as assigned yet. The
        "seen yet" check follows the visiting order of ast.walk, so it is an
        approximation.
        """
        tree = ast.parse(self.entry["code"], self.entry["src"], "exec")
        for node in ast.walk(tree):
            if isinstance(node, ast.Assign):
                for target in node.targets:
                    if isinstance(target, ast.Name):
                        self.assigned.add(target.id)
            elif isinstance(node, ast.Attribute):
                base = node.value
                if isinstance(base, ast.Name) and base.id not in self.assigned:
                    self.wanted.add(base.id)
            elif isinstance(node, ast.With):
                for item in node.items:
                    # ``with x:`` has no target (optional_vars is None), and
                    # tuple/attribute targets are not plain names
                    if isinstance(item.optional_vars, ast.Name):
                        self.assigned.add(item.optional_vars.id)

    # Provide some assertion methods the snippets can use
    def assertEqual(self, a, b):
        """Raise AssertionError if ``a != b``."""
        if a != b:
            raise AssertionError(f"{a!r} != {b!r}")

    @contextlib.contextmanager
    def setup(self, capture_output=True):
        """
        Context manager running its body in a sandbox prepared for the snippet.

        Creates a temporary directory with a copy of TEST_BUFR (as
        ``test.bufr``) and a wiped SQLite database into which it is imported.
        Creates the well-known variables listed in ``self.wanted``
        (``db``, ``explorer``, ``tr``, ``msg``) and sets ``DBA_DB`` to the
        database URL. The body runs with the temporary directory as current
        directory. If ``capture_output`` is true, stdout/stderr go to
        ``self.stdout``/``self.stderr``.

        On exit, restores the working directory, the standard streams and
        the previous ``DBA_DB`` value.
        """
        import dballe

        # Build a test environment in a temporary directory
        with tempfile.TemporaryDirectory() as root:
            bufr_path = os.path.join(root, "test.bufr")
            shutil.copy(self.TEST_BUFR, bufr_path)
            db_url = "sqlite:" + os.path.join(root, "db.sqlite")
            db = dballe.DB.connect(db_url + "?wipe=yes")
            importer = dballe.Importer("BUFR")
            with importer.from_file(bufr_path) as f:
                db.import_messages(f)

            # Create well-known variables if needed
            if "db" in self.wanted or "tr" in self.wanted or "explorer" in self.wanted:
                self.locals["db"] = db
            if "explorer" in self.wanted:
                self.locals["explorer"] = dballe.Explorer()
                with self.locals["explorer"].rebuild() as update:
                    with self.locals["db"].transaction() as tr:
                        update.add_db(tr)
            if "tr" in self.wanted:
                self.locals["tr"] = self.locals["db"].transaction()
            if "msg" in self.wanted:
                importer = dballe.Importer("BUFR")
                with importer.from_file(bufr_path) as f:
                    # Ends up holding the first message of the last bulletin
                    for msg in f:
                        self.locals["msg"] = msg[0]

            # Set env variables the snippets may use, remembering the old value
            old_dba_db = os.environ.get("DBA_DB")
            os.environ["DBA_DB"] = db_url

            # Run the code in the temp directory, optionally capturing output
            curdir = os.getcwd()
            orig_stdout, orig_stderr = sys.stdout, sys.stderr
            try:
                os.chdir(root)
                if capture_output:
                    sys.stdout = self.stdout
                    sys.stderr = self.stderr
                yield
            finally:
                if capture_output:
                    sys.stderr = orig_stderr
                    sys.stdout = orig_stdout
                os.chdir(curdir)
                if old_dba_db is None:
                    os.environ.pop("DBA_DB", None)
                else:
                    os.environ["DBA_DB"] = old_dba_db

    def store_exc(self):
        """
        Store the exception being handled as ``entry["exception"]``.

        The value is a ``(lineno, message)`` pair. ``lineno`` is the snippet
        line where the error happened: either the position reported by a
        syntax error in the snippet, or the innermost traceback frame
        belonging to the snippet. If no snippet line is found, ``lineno`` is
        None and ``message`` is the full formatted traceback.
        """
        exc_type, value, tb = sys.exc_info()
        src = self.entry["src"]
        # compile() raises syntax errors itself, so no traceback frame belongs
        # to the snippet: take the position from the exception instead
        if isinstance(value, SyntaxError) and value.filename == src:
            self.entry["exception"] = (value.lineno, f"{exc_type.__name__}: {value.msg}")
            return
        for frame, lineno in traceback.walk_tb(tb):
            if frame.f_code.co_filename == src:
                self.entry["exception"] = (lineno, repr(value))
        if "exception" not in self.entry:
            self.entry["exception"] = None, traceback.format_exc()

    def run(self):
        """
        Compile and run the snippet, annotating ``self.entry`` with the results.

        Sets ``result`` to "fail-compile", "fail-run" or "ok", and
        ``warnings`` to the warnings recorded meanwhile. ``stdout`` and
        ``stderr`` are set only if the snippet was executed. ``exception``
        is set on failure (see store_exc).
        """
        with warnings.catch_warnings(record=True) as recorded:
            warnings.simplefilter("always")
            self.entry["warnings"] = recorded
            try:
                code = compile(self.entry["code"], self.entry["src"], "exec")
            except Exception:
                self.entry["result"] = "fail-compile"
                self.store_exc()
                return

            self.check_reqs()
            with self.setup():
                try:
                    exec(code, self.globals, self.locals)
                except Exception:
                    self.entry["result"] = "fail-run"
                    self.store_exc()
                    return
                finally:
                    self.entry["stdout"] = self.stdout.getvalue()
                    self.entry["stderr"] = self.stderr.getvalue()
            self.entry["result"] = "ok"
def run_python_test(entry):
    """
    Worker function: run one collected python snippet.

    The snippet dict is annotated in place by RunEnv.run and returned. The
    return matters when this runs in a multiprocessing pool: the annotated
    dict is what travels back to the parent process.
    """
    RunEnv(entry).run()
    return entry
def format_python_result(entry, verbose=False):
    """
    Colorful formatting of a code snippet annotated with run results.

    Prints a header with the snippet location and result: red if the
    snippet raised, yellow if it only emitted warnings, green otherwise.

    If there are exceptions, warnings or stderr output (or stdout output in
    verbose mode), the snippet is then listed with line numbers. Each
    message is shown under the line it refers to. Messages that cannot be
    tied to a snippet line are shown after the listing.

    A blank line follows the output of snippets that were not clean.

    :arg entry: snippet dict annotated by RunEnv.run
    :arg verbose: also show what the snippet wrote to standard output
    """
    if "exception" in entry:
        color, ok = "red", False
    elif entry["warnings"]:
        color, ok = "yellow", False
    else:
        color, ok = "green", True
    cprint(f"{entry['src']}: {entry['source']}:{entry['line']}: {entry['result']}", color, attrs=["bold"])

    # Turn warnings and exceptions into line annotations; the None key
    # collects messages that do not refer to a line of the snippet
    line_annotations = defaultdict(list)
    if "exception" in entry:
        lineno, msg = entry["exception"]
        line_annotations[lineno].append((msg, "red"))
    for w in entry["warnings"]:
        if w.filename == entry["src"]:
            line_annotations[w.lineno].append((f"{w.category}, {w.message}", "yellow"))
        else:
            # The line number refers to another file (e.g. library code):
            # attaching it to the same-numbered snippet line would mislead
            line_annotations[None].append(
                (f"{w.filename}:{w.lineno}: {w.category}, {w.message}", "yellow"))

    def print_annotations(annotations):
        # One output line per message line, so multi-line messages stay aligned
        for msg, msg_color in annotations:
            for msg_line in msg.splitlines():
                cprint(f" ↪ {msg_line}", msg_color)

    show_code = bool(line_annotations) or bool(entry.get("stderr"))
    if verbose and entry.get("stdout"):
        show_code = True

    if show_code:
        for idx, line in enumerate(entry["code"].splitlines(), start=1):
            annotations = line_annotations.pop(idx, None)
            cprint(f" {idx:2d} {line}", "red" if annotations else "grey", attrs=["bold"])
            if annotations is not None:
                print_annotations(annotations)
        # Whatever is left did not match any line of the snippet
        for annotations in line_annotations.values():
            print_annotations(annotations)
        if verbose and entry.get("stdout"):
            for line in entry["stdout"].splitlines():
                cprint(f" O {line}", "blue")
        if entry.get("stderr"):
            for line in entry["stderr"].splitlines():
                cprint(f" E {line}", "blue")

    if not ok:
        print()
def run_tests(entries, verbose=False):
    """
    Run a list of collected test snippets and print their results.

    Only python is supported so far; the "default" language counts as
    python. Snippets in other languages are skipped with a logged warning.

    Python snippets run in a single worker process (multiprocessing.Pool(1)),
    separate from this one. Their results are printed here with
    format_python_result.

    :arg entries: list of snippet dicts; each needs at least a ``lang`` key
    :arg verbose: passed on to format_python_result
    :return: the number of python snippets whose result is not "ok"
    """
    by_lang = defaultdict(list)
    for e in entries:
        lang = e["lang"]
        if lang == "default":
            lang = "python"
        by_lang[lang].append(e)

    python_entries = by_lang.pop("python", [])
    for lang, skipped in by_lang.items():
        log.warning("skipping %d snippet(s) in unsupported language %r", len(skipped), lang)

    if not python_entries:
        return 0

    with multiprocessing.Pool(1) as p:
        results = p.map(run_python_test, python_entries)
    for entry in results:
        format_python_result(entry, verbose)
    return sum(1 for entry in results if entry.get("result") != "ok")
def get_code_from_rst(fname):
    """
    Return the code in the first non-empty code block in the given rst.

    :arg fname: path of the reStructuredText file to parse
    :return: a snippet dict with keys ``src`` (fname), ``lang`` (always
             "python"), ``code``, ``source`` and ``line`` (the block's
             position as reported by docutils); None if the document has no
             non-empty literal block
    """
    with open(fname, "rt", encoding="utf-8") as fd:
        doctree = docutils.core.publish_doctree(
            fd, source_class=docutils.io.FileInput, settings_overrides={"input_encoding": "unicode"})

    # Node.findall() supersedes the deprecated Node.traverse() in newer docutils
    find = getattr(doctree, "findall", None) or doctree.traverse
    for node in find(docutils.nodes.literal_block):
        # astext() returns the whole block: with syntax highlighting the code
        # is split into many per-token inline nodes, so the first Text
        # descendant would hold only the first token
        code = node.astext()
        if not code:
            continue
        return {
            "src": fname,
            "lang": "python",
            "code": code,
            "source": node.source,
            "line": node.line,
        }
    return None
def main():
    """
    Run documentation code snippets from the command line.

    By default, runs every snippet in the JSON file given as ``code``. With
    ``--run``, ``code`` is instead a .rst file whose first code block is run
    in this process.

    :return: exit status for sys.exit. With ``--run``, 1 if the snippet
             failed and 0 otherwise. Otherwise, 1 if run_tests returned a
             truthy value and 0 otherwise.
    :raises Fail: with ``--run``, if the file contains no code block
    """
    parser = argparse.ArgumentParser(description="Run code found in documentation code blocks")
    parser.add_argument("--verbose", "-v", action="store_true", help="verbose output")
    parser.add_argument("--debug", action="store_true", help="debug output")
    parser.add_argument("-r", "--run", action="store_true", help="run the first code snippet found in a .rst file")
    parser.add_argument("code", action="store", nargs="?", default="doc/test_code.json",
                        help="JSON file with the collected test code (.rst file with --run)")
    args = parser.parse_args()
    verbose = args.debug or args.verbose

    # Setup logging
    FORMAT = "%(asctime)-15s %(levelname)s %(message)s"
    if args.debug:
        level = logging.DEBUG
    elif args.verbose:
        level = logging.INFO
    else:
        level = logging.WARNING
    logging.basicConfig(level=level, stream=sys.stderr, format=FORMAT)

    if args.run:
        entry = get_code_from_rst(args.code)
        if entry is None:
            raise Fail(f"No code found in {args.code}")
        RunEnv(entry).run()
        format_python_result(entry, verbose=verbose)
        return 0 if entry.get("result") == "ok" else 1

    with open(args.code, "rt", encoding="utf-8") as fd:
        entries = json.load(fd)
    # A truthy return value from run_tests signals failed snippets
    return 1 if run_tests(entries, verbose=verbose) else 0
if __name__ == "__main__":
    # Fail carries a message meant for the user: print it without a
    # traceback and exit with an error status
    try:
        status = main()
    except Fail as exc:
        print(exc, file=sys.stderr)
        status = 1
    sys.exit(status)