-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathchange_graph.py
68 lines (59 loc) · 2.01 KB
/
change_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import torch.fx as fx
import torch.nn.functional as F
# Actual Model
class Module(torch.nn.Module):
    """Small LeNet-style CNN used as the torch.fx tracing example.

    Two conv -> ReLU -> max-pool stages are followed by three fully connected
    layers producing 10 outputs. ``fc1`` expects ``16 * 5 * 5`` input
    features, which is what the conv/pool stack yields for a 3-channel 32x32
    input (e.g. a tensor of shape ``(N, 3, 32, 32)``).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)

    def forward(self, x):
        """Run the network, printing the intermediate size after each stage.

        In eager mode each ``print`` shows a ``torch.Size``. Under FX symbolic
        tracing ``x`` is a Proxy: every ``x.size()`` records a ``call_method``
        node with target ``'size'`` in the graph, and ``print`` shows the
        Proxy's repr instead of a concrete shape.
        """
        x = self.pool(F.relu(self.conv1(x)))
        # Report the first stage's size the same way as the later stages.
        print(x.size())
        x = self.pool(F.relu(self.conv2(x)))
        print(x.size())
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        print(x.size())
        x = F.relu(self.fc1(x))
        print(x.size())
        x = F.relu(self.fc2(x))
        print(x.size())
        x = self.fc3(x)
        print(x.size())
        return x
m = Module()
print("original version")
# Symbolically trace the unmodified model. This runs Module.forward with
# Proxy inputs, so the print() calls inside forward fire here and show
# Proxy reprs rather than concrete shapes.
gm = torch.fx.symbolic_trace(m)
# Print the traced graph as a table (one row per node).
# NOTE(review): Graph.print_tabular relies on the third-party `tabulate`
# package being installed — confirm it is available in this environment.
gm.graph.print_tabular()
def transform(m: torch.nn.Module,
              tracer_class: type = fx.Tracer,
              *,
              size_label: str = 'size = 16') -> torch.nn.Module:
    """Trace ``m`` and relabel every ``.size()`` call in the resulting graph.

    Each ``call_method`` node whose target is ``'size'`` has its target
    replaced by ``size_label``, so the label shows up in
    ``Graph.print_tabular()``. This is a display-only hack: the label is not
    a real tensor method, so running the returned module's forward on actual
    tensors is expected to fail. The default label is a fixed string and
    does not reflect the real tensor sizes.

    Args:
        m: Module to trace.
        tracer_class: ``fx.Tracer`` (sub)class used to produce the graph.
        size_label: Replacement target for ``size`` call_method nodes.

    Returns:
        A new ``fx.GraphModule`` rooted at ``m`` holding the edited graph.
    """
    graph: fx.Graph = tracer_class().trace(m)
    # FX represents the graph as an ordered list of nodes, so walk it in order.
    # For 'call_method' nodes, ``target`` is the method *name* (a str) called
    # on ``node.args[0]``; e.g. ``x.size()`` has target 'size'.
    # TODO: show the actual tensor size instead of a fixed label
    # TODO: differentiate between layers and data
    # TODO (if possible): add a column to the tabular output displaying the size
    for node in graph.nodes:
        if node.op == 'call_method' and node.target == 'size':
            node.target = size_label
    graph.lint()  # Sanity-check that the edited graph is still well-formed.
    return fx.GraphModule(m, graph)
print("manipulated version")
# Build a fresh model and relabel its size() calls via transform().
changed = transform(Module())
# Re-trace the returned GraphModule; its generated forward calls the
# relabelled method name, so the new target appears in the retraced graph.
# NOTE(review): re-tracing regenerates node names from the targets, so they
# may not match `changed.graph`; printing `changed.graph` directly would show
# transform()'s output as-is.
gm = torch.fx.symbolic_trace(changed)
# Print the re-traced graph as a table (one row per node).
gm.graph.print_tabular()