-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
167 lines (132 loc) · 4.2 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import inspect
import pdb
import argparse
import glog
import logging
import sys
import os
import shlex
from chdrft.config.base import is_python2
from chdrft.config.env import g_env, g_ipython
import numpy as np
import random
import faulthandler
import chdrft.utils.misc as cmisc
import signal, traceback
#def quit_handler(signum,frame):
# traceback.print_stack()
#signal.signal(signal.SIGQUIT,quit_handler)
if not is_python2:
from contextlib import ExitStack
import glog
import logging
class App:
    """Script bootstrapper: parses flags, seeds RNGs and runs the caller's ``main``.

    A module-level instance is created at import time. Calling it from a script
    builds an argparse parser and parses ``argv``, then publishes ``flags`` and
    ``cache`` into the caller's globals. Finally it runs the caller's ``main``
    inside an ``ExitStack``.
    """

    def __init__(self):
        self.flags = None
        self.stack: ExitStack = None
        # Values applied onto the parsed flags namespace after parsing.
        self.override_flags = {}
        self.setup = False
        self.cache = None
        # FileHandler installed for --log_file, tracked so re-runs replace it.
        self._log_handler = None
        if not is_python2:
            self.global_context = cmisc.ExitStackWithPush()
        self.env = g_env
        self.env.setup(self)

    def setup_jup(self, cmdline='', argv=None, **kwargs):
        """Force a (re)setup from an interactive session and return the flags.

        Args:
            cmdline: shell-style argument string, split with ``shlex.split``.
                Used only when ``argv`` is None.
            argv: explicit argument list. It takes precedence over ``cmdline``
                (previously it was silently ignored).
            **kwargs: forwarded to ``__call__``.

        Returns:
            ``cmisc.A`` built from ``vars(self.flags)``.
        """
        if argv is None:
            argv = shlex.split(cmdline)
        # Called directly (not via a helper) so __call__'s frame lookup still
        # sees this method's frame as its caller.
        self(force=1, argv=argv, **kwargs, keep_open_context=1)
        self.setup = True
        return cmisc.A(vars(self.flags))

    def exit_jup(self):
        """Close the global context that interactive runs leave open."""
        self.global_context.close()

    def shell(self, n=0):
        """Drop into an interactive shell; ``n`` is the extra frame depth to skip."""
        g_ipython.drop_to_shell(n + 1)

    def wait_or_interactive(self, n=0):
        """Open a shell under Jupyter, otherwise block until the user hits Enter."""
        if g_ipython.in_jupyter:
            self.shell(n + 1)
        else:
            input('Press to stop')

    def __call__(self, force=False, argv=None, parser_funcs=None, keep_open_context=None):
        """Parse flags and run the calling module's ``main``.

        Without ``force``, this is a no-op when:
        - setup already happened;
        - the caller's ``__name__`` is not ``'__main__'``;
        - the caller defines no ``main``.

        Args:
            force: run even if the guards above would return early.
            argv: argument list for ``parse_args`` (None means ``sys.argv[1:]``).
            parser_funcs: optional iterable of callables, each given the parser
                to register extra arguments.
            keep_open_context: when true, the per-run ``ExitStack`` is not
                closed after ``main`` returns. Defaults to
                ``g_ipython.in_jupyter``.
        """
        # NOTE(review): when invoked through setup_jup, this is setup_jup's
        # frame, so the globals below belong to this module, not to the user.
        caller_globals = inspect.currentframe().f_back.f_globals
        if not force and self.setup:
            return
        if not force and caller_globals['__name__'] != '__main__':
            return
        self.setup = True
        if keep_open_context is None:
            keep_open_context = g_ipython.in_jupyter
        if 'main' not in caller_globals and not force:
            return

        want_cache = force or ('cache' in caller_globals and not is_python2)
        parser = self._build_parser(caller_globals, want_cache, parser_funcs)
        self._init_process_state()
        flags = self._parse_flags(parser, argv)
        self._configure_logging(flags)

        # Publish results into the caller's namespace.
        if 'flags' in caller_globals:
            caller_globals['flags'] = flags
        if want_cache:
            from chdrft.utils.cache import FileCacheDB
            self.cache = FileCacheDB.load_from_argparse(flags)
            caller_globals['cache'] = self.cache

        # NOTE(review): self.stack is reset to None after every successful run
        # (end of this method). This close therefore only fires when the
        # previous run raised. In keep_open_context mode the run's ExitStack is
        # never closed here; only global_context is released, via exit_jup().
        if self.stack is not None:
            self.stack.close()
        main_func = caller_globals.get('main', None)

        def go():
            try:
                if is_python2:
                    main_func()
                elif keep_open_context:
                    # Deliberately left open so contexts entered by run() survive.
                    self.run(ExitStack(), main_func)
                else:
                    with ExitStack() as stack:
                        self.run(stack, main_func)
            except Exception:
                if flags.pdb:
                    pdb.post_mortem()
                raise

        if flags.pdb:
            pdb.runcall(go)
        else:
            go()
        self.stack = None

    @staticmethod
    def _build_parser(caller_globals, want_cache, parser_funcs):
        """Build the parser: common flags, cache flags, the caller's args(), extras."""
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument('--verbosity', type=str, default='ERROR')
        parser.add_argument('--pdb', action='store_true')
        parser.add_argument('--log_file', type=str)
        parser.add_argument('--runid', type=str, default='default')
        if want_cache:
            from chdrft.utils.cache import cache_argparse
            cache_argparse(parser)
        if 'args' in caller_globals:
            caller_globals['args'](parser)
        for add_args in parser_funcs or ():
            add_args(parser)
        # Must be registered last: REMAINDER swallows every remaining token.
        parser.add_argument('other_args', nargs=argparse.REMAINDER, default=['--'])
        return parser

    @staticmethod
    def _init_process_state():
        """Seed RNGs for reproducibility and install faulthandler hooks."""
        random.seed(0)
        np.random.seed(0)
        faulthandler.register(signal.SIGQUIT)
        faulthandler.enable(file=sys.stderr, all_threads=True)
        # NOTE(review): dumps every thread's traceback to stderr on each
        # startup; looks like a debugging leftover, confirm it is intended.
        faulthandler.dump_traceback(file=sys.stderr, all_threads=True)

    def _parse_flags(self, parser, argv):
        """Parse ``argv``, strip a leading '--' from other_args, apply overrides."""
        flags = parser.parse_args(args=argv)
        if flags.other_args and flags.other_args[0] == '--':
            flags.other_args = flags.other_args[1:]
        self.flags = flags
        for key, value in self.override_flags.items():
            setattr(flags, key, value)
        return flags

    def _configure_logging(self, flags):
        """Set glog verbosity and (re)install the --log_file handler."""
        glog.setLevel(flags.verbosity)
        previous = getattr(self, '_log_handler', None)
        if previous is not None:
            # Avoid stacking a new handler on every re-run (duplicate lines,
            # leaked file handles).
            glog.logger.removeHandler(previous)
            previous.close()
            self._log_handler = None
        if flags.log_file:
            self._log_handler = logging.FileHandler(flags.log_file)
            glog.logger.addHandler(self._log_handler)

    def run(self, stack, main_func):
        """Enter per-run contexts on ``stack`` and call ``main_func`` if given.

        The entered contexts are the global context, the per-run plog file
        ``/tmp/opa_plog_<script>_<runid>.log`` (exposed as ``self.plog_file``)
        and the cache, when one is loaded.
        """
        self.stack = stack
        stack.enter_context(self.global_context)
        plog_filename = '/tmp/opa_plog_{}_{}.log'.format(
            os.path.basename(sys.argv[0]), self.flags.runid
        )
        self.plog_file = stack.enter_context(open(plog_filename, 'w'))
        # NOTE(review): truthiness check; if FileCacheDB defines __len__, an
        # empty cache would be skipped here. Confirm `is not None` isn't meant.
        if self.cache:
            stack.enter_context(self.cache)
        if main_func is not None:
            main_func()
# Module-level App instance. Constructing it runs g_env.setup(app) at import
# time (see App.__init__).
app = App()