-
Notifications
You must be signed in to change notification settings - Fork 0
/
pgspawn.py
331 lines (283 loc) · 12 KB
/
pgspawn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
from collections import namedtuple
import logging
import os
import signal
import socket
# Module-level logger; every diagnostic message emitted by this module goes through it.
logger = logging.getLogger(__name__)
def bimap_dict(key_f, val_f, d):
    """Return a new dict made from *d* by applying ``key_f`` to each key and
    ``val_f`` to each value.

    Used to coerce the key/value types of parsed description dicts, e.g.
    ``bimap_dict(int, str, {'0': 'pipe'}) == {0: 'pipe'}``.  When two
    converted keys coincide, the later item of *d* wins.
    """
    return dict((key_f(key), val_f(value)) for key, value in d.items())
def str2sig(s):
    """Convert a signal specification to its signal number.

    Accepted forms:
      * a decimal number, e.g. ``"15"`` (returned as-is, not validated);
      * a full signal name, e.g. ``"SIGTERM"``;
      * a name without the ``SIG`` prefix, e.g. ``"TERM"``.
    Names are case-insensitive and surrounding whitespace is ignored.

    :raises ValueError: if *s* is neither a number nor a known signal name.
    """
    try:
        return int(s)
    except ValueError:
        pass
    name = s.strip().upper()
    # Also accept the short form ("TERM") commonly used by kill(1).
    candidates = (name, "SIG" + name)
    for sig in signal.Signals:
        if sig.name in candidates:
            return sig.value
    raise ValueError("unknown signal '{}'".format(s))
class GraphException(Exception):
    """Raised when a graph description fails one of the consistency checks."""
class Node(namedtuple('Node', ('command', 'inputs', 'outputs', 'sockets', 'separate_group', 'signals'))):
    """One process of the graph.

    Fields:
      command        -- argv list, every element converted to ``str``
      inputs         -- dict child fd (int) -> name of a pipe the child reads
      outputs        -- dict child fd (int) -> name of a pipe the child writes
      sockets        -- dict child fd (int) -> socket name
      separate_group -- bool; whether the child gets its own process group
      signals        -- list of signal numbers to forward to the child
    """

    @classmethod
    def from_dict(cls, description):
        """Construct a Node from a parsed description dict.

        Only ``command`` is required (a missing one raises ``KeyError``);
        every other field falls back to an empty/false default.  Unknown
        keys are logged and otherwise ignored.  Signal entries may be
        numbers or names; they are converted with ``str2sig``.
        """
        extra_keys = description.keys() - set(cls._fields)
        if extra_keys:
            logger.warning("Unknown keys in node description dict: {}".format(extra_keys))
        lookup = description.get
        return cls(
            command=list(map(str, description['command'])),
            inputs=bimap_dict(int, str, lookup('inputs', {})),
            outputs=bimap_dict(int, str, lookup('outputs', {})),
            sockets=bimap_dict(int, str, lookup('sockets', {})),
            separate_group=bool(lookup('separate_group', False)),
            signals=[str2sig(str(entry)) for entry in lookup('signals', [])],
        )
class Graph(namedtuple('Graph', ('inputs', 'outputs', 'sockets', 'nodes'))):
    """A whole process graph: graph-level pipes/sockets plus the nodes to spawn.

    Fields:
      inputs  -- dict pipe name -> fd (int) supplied from outside for that pipe
      outputs -- dict pipe name -> fd (int) supplied from outside for that pipe
      sockets -- dict socket name -> fd (int) supplied from outside
      nodes   -- list of Node
    """

    @classmethod
    def from_dict(cls, description):
        """Build a Graph from a parsed description dict and validate it.

        Unknown top-level keys only produce a warning.

        :raises GraphException: if one of the ``check_*`` validations fails.
        """
        unknown_keys = description.keys() - set(cls._fields)
        if unknown_keys:
            logger.warning("Unknown keys in graph description dict: %s", unknown_keys)
        g = cls(
            inputs=bimap_dict(str, int, description.get('inputs', {})),
            outputs=bimap_dict(str, int, description.get('outputs', {})),
            sockets=bimap_dict(str, int, description.get('sockets', {})),
            nodes=list(map(Node.from_dict, description.get('nodes', []))),
        )
        g.check_for_pipe_collisions()
        g.check_pipe_directions()
        g.check_for_fd_collisions()
        g.check_sockets()
        g.check_for_dead_ends()
        return g

    def check_for_pipe_collisions(self):
        """Reject pipe names declared both as global input and global output.

        :raises GraphException: listing the colliding names.
        """
        colliding = self.inputs.keys() & self.outputs.keys()
        if colliding:
            raise GraphException("Some pipes specified as both global inputs and outputs: {}".format(colliding))

    def check_pipe_directions(self):
        """Reject nodes writing to a global input or reading from a global output.

        :raises GraphException: on the first offending pipe found.
        """
        for node_id, node in enumerate(self.nodes):
            # Hoisted so each membership test below is O(1).
            node_outputs = set(node.outputs.values())
            node_inputs = set(node.inputs.values())
            for pipe_name in self.inputs.keys():
                if pipe_name in node_outputs:
                    raise GraphException(
                        "Pipe named '{}', defined as global input, "
                        "is used as output in node {}.".format(
                            pipe_name, node_id,
                        )
                    )
            for pipe_name in self.outputs.keys():
                if pipe_name in node_inputs:
                    raise GraphException(
                        "Pipe named '{}', defined as global output, "
                        "is used as input in node {}.".format(
                            pipe_name, node_id,
                        )
                    )

    def check_for_fd_collisions(self):
        """Reject nodes that attach more than one pipe/socket to the same child fd.

        :raises GraphException: naming the node index and colliding fds.
        """
        for node_id, node in enumerate(self.nodes):
            colliding_fds = (
                (node.inputs.keys() & node.outputs.keys())
                | (node.inputs.keys() & node.sockets.keys())
                | (node.outputs.keys() & node.sockets.keys())
            )
            if colliding_fds:
                raise GraphException(
                    "Multiple pipes/sockets specified for single fd. "
                    "I'm sorry, I'm afraid I can't connect that. (node {}, fds {})".format(
                        node_id, colliding_fds,
                    )
                )

    def check_for_dead_ends(self):
        """Warn (without raising) about pipes that are only written or only read."""
        written_pipes = set(self.inputs.keys())
        read_pipes = set(self.outputs.keys())
        for node in self.nodes:
            written_pipes.update(node.outputs.values())
            read_pipes.update(node.inputs.values())
        only_written = written_pipes - read_pipes
        only_read = read_pipes - written_pipes
        if only_written:
            logger.warning("Some pipes are never read: %s", only_written)
        if only_read:
            logger.warning("Some pipes are never written: %s", only_read)

    def check_sockets(self):
        """Warn (without raising) about socket names not used by exactly two ends.

        A socket name joins two ends.  A name declared in the graph-level
        ``sockets`` already supplies one end (the given fd), so it counts as
        one use; one node using it is then a correct pairing.
        """
        # dict socket_id -> number of uses; graph-level sockets start at one.
        socket_uses = dict.fromkeys(self.sockets, 1)
        for node_id, node in enumerate(self.nodes):
            for socket_id in node.sockets.values():
                n = socket_uses.get(socket_id, 0) + 1
                if n > 2:
                    logger.warning(
                        "Socket name '%s' is used more than two times (node %s). "
                        "I can take this. And you can easily make a mistake.",
                        socket_id, node_id,
                    )
                socket_uses[socket_id] = n
        for socket_id, n in socket_uses.items():
            if n == 1:
                logger.warning(
                    "Socket name '%s' is used only one time. "
                    "The other end will be flapping in the breeze (until we close it).",
                    socket_id,
                )
def apply_fd_mapping(fd_mapping):
    """Rearrange file descriptors so that each target fd refers to its source.

    Takes dict target fd -> present fd.  For every key ``target``, fd number
    ``target`` ends up as an ``os.dup2`` copy of the descriptor that
    ``fd_mapping[target]`` referred to on entry.  A target fd that is still
    needed as a source by some entry is first copied aside with ``os.dup``,
    so swaps such as ``{3: 4, 4: 3}`` work.

    Caveats:
      * ``fd_mapping`` is mutated in place (its values are rewritten as fds
        move); do not rely on its contents afterwards.
      * Source fds and the temporary ``os.dup`` copies are not closed.
      * fds written via ``os.dup2`` are created non-inheritable; the caller
        must mark them inheritable if they should survive ``exec``.
    """
    def _dup_mapping(fd, new_fd):
        # fd's descriptor is now also available at new_fd: repoint every
        # entry that still uses fd as its source to new_fd instead.
        logger.debug("fd {} duped to {}".format(fd, new_fd))
        for target_fd in fd_mapping.keys():
            if fd_mapping[target_fd] == fd:
                fd_mapping[target_fd] = new_fd
    # Iterating keys while only values change is safe (dict size is constant).
    for target_fd in fd_mapping.keys():
        fd = fd_mapping[target_fd]
        if fd == target_fd:
            # nothing to do
            logger.debug("fd {} already in place".format(fd))
            continue
        # if needed make target fd free
        # (target_fd is still some entry's source: save a copy before dup2 overwrites it)
        if target_fd in fd_mapping.values():
            saved_fd = os.dup(target_fd)
            _dup_mapping(target_fd, saved_fd)
        os.dup2(fd, target_fd, inheritable=False)
        # Entries sharing this source may now use target_fd as their source.
        _dup_mapping(fd, target_fd)
class PipeGraphSpawner:
    """Spawns the processes of a pipe graph and wires their fds together.

    Pipes and socket pairs are created lazily, keyed by name, the first time a
    spawned process refers to them.  Pipe ends stay open in this process until
    ``close_fds`` is called; a socket end is closed here right after the
    process that received it has been forked.
    """

    # command: argv of a spawned child; signals: signal numbers forwarded to it.
    Process = namedtuple('Process', ('command', 'signals'))

    @classmethod
    def from_graph(cls, graph):
        """Create a spawner for ``graph`` and spawn all of its nodes, in order.

        Graph-level inputs, outputs and sockets are registered as pre-existing
        fds before any node is spawned.
        """
        spawner = cls(
            inputs=graph.inputs,
            outputs=graph.outputs,
            # Without this, a node naming a graph-level socket would get a
            # fresh socketpair instead of the fd supplied in the graph.
            sockets=graph.sockets,
        )
        for node in graph.nodes:
            spawner.spawn(
                node.command,
                node.inputs, node.outputs, node.sockets,
                node.separate_group, node.signals,
            )
        return spawner

    def __init__(self, inputs=None, outputs=None, sockets=None):
        """
        :param inputs: dict pipe name -> existing fd used as that pipe's reading end
        :param outputs: dict pipe name -> existing fd used as that pipe's writing end
        :param sockets: dict socket name -> existing fd handed to the first
            process that uses that socket name
        All given fds are marked non-inheritable, so children only receive
        them through an explicit fd mapping in ``spawn``.
        """
        self._reading_ends = {}
        self._writing_ends = {}
        self._socket_other_ends = {}
        # collection of running subprocesses. dict pid -> Process
        self._processes = {}

        def register_fds(our_dict, input_dict):
            for name, fd in (input_dict or {}).items():
                os.set_inheritable(fd, False)
                our_dict[name] = fd

        register_fds(self._writing_ends, outputs)
        register_fds(self._reading_ends, inputs)
        register_fds(self._socket_other_ends, sockets)

    def spawn(self, command, inputs, outputs, sockets, separate_group, signals):
        """Fork and exec ``command`` with its fds wired to the named pipes/sockets.

        :param command: argv list; ``command[0]`` is resolved via ``os.execvp``
        :param inputs: dict child fd -> name of the pipe to read from
        :param outputs: dict child fd -> name of the pipe to write to
        :param sockets: dict child fd -> socket name
        :param separate_group: if true, the child calls ``setpgid(0, 0)`` before exec
        :param signals: signal numbers that ``dispatch_signal`` forwards to the child
        :returns: pid of the child
        """
        fd_mapping = {}
        fds_to_be_closed_in_parent = []
        for subprocess_fd, pipe_id in inputs.items():
            assert subprocess_fd not in fd_mapping
            fd_mapping[subprocess_fd] = self._reading_end_fd(pipe_id)
        for subprocess_fd, pipe_id in outputs.items():
            assert subprocess_fd not in fd_mapping
            fd_mapping[subprocess_fd] = self._writing_end_fd(pipe_id)
        for subprocess_fd, socket_id in sockets.items():
            assert subprocess_fd not in fd_mapping
            fd = self._get_and_clear_socket_end(socket_id)
            fd_mapping[subprocess_fd] = fd
            fds_to_be_closed_in_parent.append(fd)
        pid = os.fork()
        if pid == 0:
            # Child: never return into the caller's code.  An exception that
            # escaped here (e.g. command not found) would otherwise make the
            # child keep running the parent's program.
            try:
                # prepare fds
                apply_fd_mapping(fd_mapping)
                for fd in fd_mapping.keys():
                    os.set_inheritable(fd, True)
                if separate_group:
                    # create new process group
                    logger.debug("creating new process group")
                    os.setpgid(0, 0)
                # run target executable
                os.execvp(command[0], command)
            except BaseException:
                logger.exception("failed to exec command=%s", command)
            finally:
                # Only reached if exec did not happen; 127 mirrors the shell's
                # "command not found" status.
                os._exit(127)
        assert pid not in self._processes
        self._processes[pid] = self.Process(command=command, signals=signals)
        logger.info(
            "process %d spawned command=%s fd_mapping=%s",
            pid, command, fd_mapping,
        )
        for fd in fds_to_be_closed_in_parent:
            logger.debug("fd %s: closing", fd)
            os.close(fd)
        return pid

    def _reading_end_fd(self, pipe_id):
        """Return the reading end of pipe ``pipe_id``, creating the pipe if needed."""
        if pipe_id not in self._reading_ends:
            self._make_pipe(pipe_id)
        return self._reading_ends[pipe_id]

    def _writing_end_fd(self, pipe_id):
        """Return the writing end of pipe ``pipe_id``, creating the pipe if needed."""
        if pipe_id not in self._writing_ends:
            self._make_pipe(pipe_id)
        return self._writing_ends[pipe_id]

    def _get_and_clear_socket_end(self, socket_id):
        """Return one end of the socket named ``socket_id``.

        Behold! This method is unexpectedly impure: calling it twice gives
        different results.  If an end for ``socket_id`` is stored, it is
        removed from storage and returned.  Otherwise a new socketpair is
        created; one end is returned and the other is stored for the next
        caller.
        The caller is responsible for the returned fd, and in particular
        should close it after use.
        """
        fd = self._socket_other_ends.pop(socket_id, None)
        if fd is not None:
            return fd

        def detach_fd(sock):
            # Keep the raw fd; the socket object no longer owns it.
            raw_fd = sock.detach()
            assert raw_fd >= 0
            return raw_fd

        fd_a, fd_b = map(detach_fd, socket.socketpair())
        logger.info("socket pair '%s' created, fds %s <-> %s", socket_id, fd_a, fd_b)
        self._socket_other_ends[socket_id] = fd_b
        return fd_a

    def _make_pipe(self, pipe_id):
        """Create pipe ``pipe_id`` and record both of its ends."""
        reading_end, writing_end = os.pipe()
        logger.info("pipe '%s' created, fds %s -> %s", pipe_id, writing_end, reading_end)
        assert pipe_id not in self._writing_ends
        self._writing_ends[pipe_id] = writing_end
        assert pipe_id not in self._reading_ends
        self._reading_ends[pipe_id] = reading_end

    def close_fds(self):
        """Close every fd still held by the spawner and forget it.

        Safe to call more than once: the bookkeeping dicts are emptied, so a
        second call does not close the (possibly reused) fd numbers again.
        """
        for fd in self._writing_ends.values():
            logger.debug("fd %s: closing", fd)
            os.close(fd)
        self._writing_ends.clear()
        for fd in self._reading_ends.values():
            logger.debug("fd %s: closing", fd)
            os.close(fd)
        self._reading_ends.clear()
        for fd in self._socket_other_ends.values():
            logger.warning("fd %s: closing (unused socket end)", fd)
            os.close(fd)
        self._socket_other_ends.clear()

    def join(self):
        """Wait until every spawned process has terminated.

        :returns: dict pid -> status.  The status is the exit code for a
            process that exited normally, or ``-signum`` for one killed by a
            signal (the ``subprocess`` returncode convention).
        """
        statuses = {}
        while self._processes:
            pid, code = os.wait()
            process = self._processes.pop(pid, None)
            if process is None:
                logger.warning("got exit status for unknown process %d", pid)
                continue
            if os.WIFSIGNALED(code):
                signum = os.WTERMSIG(code)
                status = -signum
                logger.warning(
                    "process %d (%s) killed by signal %d",
                    pid, process.command, signum,
                )
            else:
                status = os.WEXITSTATUS(code)
                if status != 0:
                    logger.warning(
                        "process %d (%s) exited with unsuccessful code %d",
                        pid, process.command, status,
                    )
                else:
                    logger.info(
                        "process %d (%s) exited with status %d",
                        pid, process.command, status,
                    )
            statuses[pid] = status
        return statuses

    def dispatch_signal(self, sig):
        """Forward ``sig`` to every running process whose ``signals`` include it."""
        logger.debug("got %s (%d)", signal.Signals(sig).name, sig)
        # Iterate over a snapshot so the loop is unaffected if the dict changes.
        for pid, process in list(self._processes.items()):
            if sig not in process.signals:
                continue
            logger.info("killing %d (%s) with %s (%d)", pid, process.command, signal.Signals(sig).name, sig)
            try:
                os.kill(pid, sig)
            except ProcessLookupError:
                # The process is already gone; forwarding is best-effort.
                logger.debug("process %d no longer exists", pid)