forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnomnigraph.py
141 lines (114 loc) · 4.24 KB
/
nomnigraph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import caffe2.python._import_c_extension as C
from caffe2.python import core
from caffe2.proto import caffe2_pb2
import os
from subprocess import Popen, PIPE
import errno
class NNModule(object):
    """Python wrapper around the nomnigraph ``NNModule`` C extension object.

    The wrapped object is stored in ``self._NNModule``. Most methods forward
    to its data-flow graph (``self._NNModule.dataFlow()``) or to helper
    functions exposed by ``caffe2.python._import_c_extension`` (``C``).
    """

    def __init__(self, net=None, device_map=None):
        """Create an empty module or import one from a Caffe2 net.

        Args:
            net: ``None`` (create an empty module), a ``core.Net`` or a
                ``caffe2_pb2.NetDef``.
            device_map: optional mapping whose values are protobuf messages
                (each is serialized with ``SerializeToString``). When given,
                the net is imported via ``C.NNModuleFromProtobufDistributed``.
                Ignored when ``net`` is ``None``.

        Raises:
            Exception: if ``net`` is not ``None``, a ``core.Net`` or a
                ``caffe2_pb2.NetDef``.
        """
        if net is None:
            self._NNModule = C.NNModule()
            return

        if isinstance(net, core.Net):
            serialized_proto = net.Proto().SerializeToString()
        elif isinstance(net, caffe2_pb2.NetDef):
            serialized_proto = net.SerializeToString()
        else:
            # Reject unsupported types up front so that neither import path
            # ever receives ``None`` as the serialized proto.
            raise Exception(
                "NNModule can be constructed with core.Net or caffe2_pb2.NetDef types"
            )

        # NOTE: an empty NetDef serializes to b"", which is falsy; validity is
        # decided by the type check above, not by the truthiness of the bytes.
        if device_map is not None:
            # Distributed import: serialize every value of the device map.
            serialized_device_map = {
                k: device_map[k].SerializeToString() for k in device_map
            }
            self._NNModule = C.NNModuleFromProtobufDistributed(
                serialized_proto, serialized_device_map
            )
        else:
            # Default import; the binding also returns an op list, which is
            # kept on the instance as ``_OpList``.
            self._NNModule, self._OpList = C.NNModuleFromProtobuf(serialized_proto)

    @property
    def dataFlow(self):
        """Return the module's data-flow graph object."""
        return self._NNModule.dataFlow()

    @property
    def controlFlow(self):
        """Return the module's execution order (``getExecutionOrder()``)."""
        return self._NNModule.getExecutionOrder()

    @property
    def nodes(self):
        """Return all nodes of the data-flow graph."""
        return self._NNModule.dataFlow().nodes

    @property
    def operators(self):
        """Return the operator nodes of the data-flow graph."""
        return self._NNModule.dataFlow().operators

    @property
    def tensors(self):
        """Return the tensor nodes of the data-flow graph."""
        return self._NNModule.dataFlow().tensors

    def createNode(self, val):
        """Create a node holding ``val`` in the data-flow graph and return it."""
        return self._NNModule.dataFlow().createNode(val)

    def deleteNode(self, node):
        """Delete ``node`` from the data-flow graph."""
        return self._NNModule.dataFlow().deleteNode(node)

    def createEdge(self, a, b):
        """Create an edge from ``a`` to ``b`` in the data-flow graph."""
        return self._NNModule.dataFlow().createEdge(a, b)

    def deleteEdge(self, a, b=None):
        """Delete an edge from the data-flow graph.

        With two arguments, deletes the edge between nodes ``a`` and ``b``.
        With only ``a``, forwards that single argument to the binding's
        ``deleteEdge`` (presumably an edge handle -- confirm with the C++ API).
        """
        # Compare against None explicitly: graph handles should not be
        # distinguished by their truthiness.
        if b is not None:
            self._NNModule.dataFlow().deleteEdge(a, b)
        else:
            self._NNModule.dataFlow().deleteEdge(a)

    def replaceNode(self, old_node, new_node):
        """Replace ``old_node`` with ``new_node`` in the data-flow graph."""
        return self._NNModule.dataFlow().replaceNode(old_node, new_node)

    def replaceProducer(self, tensor, new_producer):
        """Make ``new_producer`` the producer of ``tensor`` (via ``C``)."""
        C.replaceProducer(tensor, new_producer)

    def replaceAllUsesWith(self, old_tensor, new_tensor):
        """Redirect every use of ``old_tensor`` to ``new_tensor`` (via ``C``)."""
        C.replaceAllUsesWith(old_tensor, new_tensor)

    def replaceAsConsumer(self, old_consumer, new_consumer):
        """Substitute ``new_consumer`` for ``old_consumer`` (via ``C``)."""
        C.replaceAsConsumer(old_consumer, new_consumer)

    def replaceSubgraph(self, subgraph, new_node, inputs, outputs):
        """Replace ``subgraph`` with ``new_node`` wired to ``inputs``/``outputs``."""
        self._NNModule.replaceSubgraph(subgraph, new_node, inputs, outputs)

    def deleteSubgraph(self, subgraph):
        """Delete ``subgraph`` from the module."""
        self._NNModule.deleteSubgraph(subgraph)

    def createUniqueDataNode(self, prefix="_unique"):
        """Create a data node whose name is derived from ``prefix``."""
        return self._NNModule.createUniqueDataNode(prefix)

    def convertToCaffe2Proto(self, old_proto=None):
        """Convert the module back into a ``caffe2_pb2.NetDef``.

        Args:
            old_proto: optional ``NetDef`` handed to the C++ converter as the
                base proto; a fresh empty ``NetDef`` is used when omitted.

        Returns:
            A new ``caffe2_pb2.NetDef`` parsed from the converter's
            serialized output (``old_proto`` itself is not returned).
        """
        if old_proto is None:
            old_proto = caffe2_pb2.NetDef()
        output = self._NNModule.convertToCaffe2Proto(old_proto)
        new_proto = caffe2_pb2.NetDef()
        new_proto.ParseFromString(output)
        return new_proto

    def match(self, pattern):
        """Yield subgraph matches of ``pattern`` in the data-flow graph.

        ``C.matchSubgraph`` is tried with every node of the graph as the
        root, and each truthy result is yielded. This is a generator, so the
        graph is walked lazily.
        """
        for node in self.dataFlow.getMutableNodes():
            m = C.matchSubgraph(node, pattern)
            if m:
                yield m
def render(s):
    """Display a graph description, through ``graph-easy`` when available.

    ``s`` is converted with ``str()``. If a ``graph-easy`` executable is
    found on ``PATH``, the text is piped to its stdin and its output goes to
    this process's stdout. Otherwise the text is simply printed.
    """
    s = str(s)
    # Fall back to os.defpath when PATH is unset; the old os.environ["PATH"]
    # lookup raised KeyError in that case.
    search_path = os.environ.get("PATH", os.defpath)
    has_graph_easy = any(
        os.access(os.path.join(directory, "graph-easy"), os.X_OK)
        for directory in search_path.split(os.pathsep)
    )
    if not has_graph_easy:
        print(s)
        return
    p = Popen(["graph-easy"], stdin=PIPE)
    # communicate() writes the input, closes stdin and waits for the child.
    # It ignores a broken pipe (EPIPE/EINVAL) if graph-easy exits early.
    # The previous code only guarded write(); the flush inside
    # stdin.close() could still raise.
    p.communicate(s.encode("utf-8"))
# Public re-exports of nomnigraph types from the C extension.
# ``Operator`` and ``Data`` are shorter aliases bound to the very same
# objects as ``NeuralNetOperator`` and ``NeuralNetData``.
NeuralNetOperator = Operator = C.NeuralNetOperator
NeuralNetData = Data = C.NeuralNetData
NNSubgraph, NNMatchGraph = C.NNSubgraph, C.NNMatchGraph
Graph, Annotation = C.Graph, C.Annotation