-
Notifications
You must be signed in to change notification settings - Fork 0
/
stochasticsearch.t
278 lines (253 loc) · 10.3 KB
/
stochasticsearch.t
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
require("bithelpers")
require("util")
local cc = require("circuit")
local sim = require("simulation")
--- Module table: public entry points of the stochastic circuit search.
local ss = {}

--- Default settings for ss.stochasticSearch.
-- Callers typically copy this table and override individual fields.
ss.defaultSearchSettings = {
  -- Relative weights of the four proposal (rewrite) kinds used by createRewrite.
  addMass = 1,
  deleteMass = 1,
  inputSwapMass = 1,
  lutChangeMass = 1,

  -- Iteration budget: total iterations, and the length of each independent
  -- search before restarting from the initial circuit.
  totalIterations = 10 * 1000 * 1000,
  iterationsBetweenRestarts = 1000 * 1000,

  -- Bounds on #circuit.internalNodes that add/delete proposals respect.
  maxInternalNodes = 100,
  minInternalNodes = 0,

  -- Inverse temperature of the acceptance test in acceptRewrite.
  beta = 1.0,

  -- Weights of the correctness, critical-path and size terms of the cost.
  weightCorrect = 1.0,
  weightCritical = 1.0,
  weightSize = 1.0,
}
--- Error of `circuit` on a single test case.
-- Simulates the circuit on test.input and compares the result against
-- test.output with hammingDistance (a global, presumably from bithelpers).
-- @param circuit circuit to simulate
-- @param test table with `input` and `output` fields
-- @return hammingDistance(test.output, simulated output)
local function evaluate(circuit, test)
  local simulated = sim.runCircuit(circuit, test.input)
  local expected = test.output
  return hammingDistance(expected, simulated)
end
--- Size term of the cost: how many internal nodes the circuit has.
-- @param circuit table with an `internalNodes` sequence
-- @return #circuit.internalNodes
local function sizeCost(circuit)
  local nodes = circuit.internalNodes
  return #nodes
end
--- Longest path, in edges, from `node` back to a node with no inputs.
-- A node whose `inputs` field is nil or empty counts as a source (distance 0).
-- Results are memoised per node in `memo`, so a node reachable along many
-- paths (shared/reconvergent fan-in) is evaluated once. Without this the
-- recursion re-walks every path and is exponential in depth.
-- Assumes the node graph is acyclic (a cycle would recurse forever, as before).
-- @param node circuit node
-- @param memo optional table mapping node -> distance; created when omitted
-- @return longest distance to a source node
local function furthestDistanceToInput(node, memo)
  memo = memo or {}
  local cached = memo[node]
  if cached ~= nil then
    return cached
  end
  local distance = 0
  if node.inputs then
    for _, input in ipairs(node.inputs) do
      distance = math.max(distance, furthestDistanceToInput(input, memo) + 1)
    end
  end
  memo[node] = distance
  return distance
end

--- Critical path length of a circuit: the longest input-to-output path
-- over all of circuit.outputs (0 when there are no outputs).
-- One memo table is shared across all outputs, so common upstream logic is
-- measured only once.
-- @param circuit table with an `outputs` sequence of nodes
-- @return critical path length in edges
local function criticalPathLength(circuit)
  local memo = {}
  local maxPathLength = 0
  for _, output in ipairs(circuit.outputs) do
    maxPathLength = math.max(maxPathLength, furthestDistanceToInput(output, memo))
  end
  return maxPathLength
end
--- Delay term of the cost: the circuit's critical path length.
-- The extra parentheses keep exactly one return value.
local function criticalCost(circuit)
  return (criticalPathLength(circuit))
end
--- Correctness term of the cost.
-- Sums evaluate() over every test case and multiplies the sum by
-- (#validationCases + 1). If that is non-zero it is returned immediately.
-- Only when every test case passes are the validation cases evaluated, and
-- their summed error becomes the result.
-- @param proposal circuit being scored
-- @param testCases sequence of test cases (see evaluate)
-- @param validationCases sequence of extra cases, run only after all tests pass
-- @return non-negative error score; 0 means every test and validation case passed
local function errorCost(proposal, testCases, validationCases)
  local testResult = 0
  log.trace(#testCases, " testCases")
  for _, test in ipairs(testCases) do
    testResult = testResult + evaluate(proposal, test)
  end
  log.info("Raw test case score ", testResult)
  -- Scale test-case error so that a proposal failing any test case scores at
  -- least #validationCases + 1. That is worse than a proposal that passes all
  -- tests and then misses each validation case by at most 1.
  -- NOTE(review): evaluate() returns a Hamming distance, so one validation case
  -- can contribute more than 1; the ordering only holds when that per-case
  -- error is at most 1 — confirm intent.
  testResult = testResult * (#validationCases + 1)
  log.info("Adjusted test case score ", testResult)
  if testResult == 0 then
    log.info("Proposal passed all tests, running validation")
    for _, test in ipairs(validationCases) do
      testResult = testResult + evaluate(proposal, test)
    end
    log.info("Final test case score ", testResult)
    -- Report a validation failure only when validation actually found errors.
    if testResult ~= 0 then
      log.info("validation cases failed: ", testResult)
    end
  end
  return testResult
end
--- Total cost of a proposal: the weighted sum of its correctness,
-- critical-path and size terms.
-- @param proposal circuit being scored
-- @param testCases test cases (see errorCost)
-- @param validationCases validation cases (see errorCost)
-- @param settings table supplying weightCorrect, weightCritical, weightSize
-- @return totalCost the weighted sum
-- @return errCost the unweighted correctness term (0 means fully correct)
local function cost(proposal, testCases, validationCases, settings)
  local errCost = errorCost(proposal, testCases, validationCases)
  log.info("error cost: " .. errCost)
  local correctTerm = errCost * settings.weightCorrect
  local criticalTerm = criticalCost(proposal) * settings.weightCritical
  local sizeTerm = sizeCost(proposal) * settings.weightSize
  local totalCost = correctTerm + criticalTerm + sizeTerm
  log.info("total cost " .. totalCost)
  return totalCost, errCost
end
--- Sum of all four proposal masses, before any size-based eligibility filtering.
local function totalProposalMass(settings)
  local total = settings.addMass
  total = total + settings.deleteMass
  total = total + settings.inputSwapMass
  return total + settings.lutChangeMass
end
--
--- Proposal: insert a new LUT node on a randomly chosen wire.
-- The new node's first input is the wire's current input; the other three
-- are ground. Its LUT value is drawn uniformly from [0, 65535].
-- @param original circuit; cc.deepCopy is applied first and only the copy is modified
-- @param rnd number in [0, 1) selecting which wire to use
-- @return the rewritten copy
local function addRewrite(original, rnd)
  local newCircuit = cc.deepCopy(original)
  log.trace("About to select wire")
  -- rnd can be exactly 0 (math.random() is in [0, 1)); clamp so the 1-based
  -- wire index never becomes 0.
  local wireIndex = math.max(1, math.ceil(rnd * cc.wireCount(newCircuit)))
  local wire = cc.selectWire(newCircuit, wireIndex)
  -- TODO: should we select inputs at random upstream from parent node?
  local inputs = { wire.input, cc.getGround(newCircuit), cc.getGround(newCircuit), cc.getGround(newCircuit) }
  -- TODO: should this not be random?
  -- Uniform over [0, 2^16 - 1]; avoids math.pow, which is absent in Lua 5.3+.
  local lutValue = math.random(0, 65535)
  log.trace("About to add LUT")
  cc.addLUTNode(newCircuit, inputs, wire, lutValue)
  return newCircuit
end
--- Proposal: delete one randomly chosen internal node.
-- @param original circuit; cc.deepCopy is applied first and only the copy is modified
-- @param rnd number in [0, 1) selecting which internal node to delete
-- @return the rewritten copy
local function deleteRewrite(original, rnd)
  local newCircuit = cc.deepCopy(original)
  -- rnd can be exactly 0 (math.random() is in [0, 1)); clamp so the 1-based
  -- node index never becomes 0.
  local nodeIndex = math.max(1, math.ceil(rnd * cc.internalNodeCount(newCircuit)))
  local node = cc.selectInternalNode(newCircuit, nodeIndex)
  cc.deleteNode(newCircuit, node)
  return newCircuit
end
--- Proposal: reconnect one input of a randomly chosen non-input node to a
-- node chosen at random from cc.upstreamNodes of that node.
-- If the chosen node is an output, input slot 1 is replaced. Otherwise one
-- of slots 1..4 is picked at random.
-- @param original circuit; cc.deepCopy is applied first and only the copy is modified
-- @param rnd number in [0, 1) selecting which non-input node to rewire
-- @return the rewritten copy (returned unchanged if no upstream candidate exists)
local function inputSwapRewrite(original, rnd)
  local newCircuit = cc.deepCopy(original)
  -- Count on the copy we select from. rnd can be exactly 0, so clamp the
  -- 1-based index to at least 1.
  local nodeIndex = math.max(1, math.ceil(rnd * cc.nonInputNodeCount(newCircuit)))
  local node, isOutput = cc.selectNonInputNode(newCircuit, nodeIndex)
  log.trace("Getting potential inputs")
  local potentialInputs = cc.upstreamNodes(newCircuit, node)
  log.info(#potentialInputs, " potential inputs")
  -- math.random(0) raises "interval is empty"; with nothing to connect to,
  -- return the unmodified copy instead of aborting the search.
  if #potentialInputs == 0 then
    return newCircuit
  end
  log.trace("Selecting input")
  local chosenInput = potentialInputs[math.random(#potentialInputs)]
  log.trace("Setting Input")
  -- NOTE(review): 4 assumes every internal node is a 4-input LUT — confirm against cc.
  local slot = 1
  if not isOutput then
    slot = math.random(4)
  end
  cc.setInputOfNode(newCircuit, node, slot, chosenInput)
  log.trace("Set")
  return newCircuit
end
--- Proposal: replace the LUT value of one randomly chosen internal node
-- with a new value drawn uniformly from [0, 65535].
-- @param original circuit; cc.deepCopy is applied first and only the copy is modified
-- @param rnd number in [0, 1) selecting which internal node to change
-- @return the rewritten copy
local function lutChangeRewrite(original, rnd)
  log.trace("in lutChangeRewrite")
  local newCircuit = cc.deepCopy(original)
  log.trace("making index")
  -- rnd can be exactly 0 (math.random() is in [0, 1)); clamp so the 1-based
  -- node index never becomes 0.
  local index = math.max(1, math.ceil(rnd * cc.internalNodeCount(newCircuit)))
  log.trace("about to select")
  local node = cc.selectInternalNode(newCircuit, index)
  log.trace("selectInternalNode")
  -- TODO: should this not be random?
  -- Uniform over [0, 2^16 - 1]; avoids math.pow, which is absent in Lua 5.3+.
  local lutValue = math.random(0, 65535)
  cc.setLUTValue(node, lutValue)
  log.trace("lutValue")
  return newCircuit
end
--- Choose one proposal move at random and apply it to currentCircuit.
-- Each move is picked with probability proportional to its mass in
-- `settings`, among the moves legal for the current size
-- N = #currentCircuit.internalNodes:
--   add        only while N < settings.maxInternalNodes
--   delete     only while N > settings.minInternalNodes
--   inputSwap  always
--   lutChange  only while N > 0
-- Moves with zero mass are never chosen. The chosen rewrite receives a
-- number in [0, 1) that it uses to pick which wire/node to act on.
-- @return the circuit returned by the chosen rewrite
-- @raise error if no legal move has positive mass
local function createRewrite(currentCircuit, settings)
  local N = #currentCircuit.internalNodes

  -- {log label, mass, rewrite function} for every legal, positive-mass move,
  -- in the same order the original selection used.
  local moves = {}
  if N < settings.maxInternalNodes and settings.addMass > 0 then
    moves[#moves + 1] = { "addRewrite", settings.addMass, addRewrite }
  end
  if N > settings.minInternalNodes and settings.deleteMass > 0 then
    moves[#moves + 1] = { "deleteRewrite", settings.deleteMass, deleteRewrite }
  end
  if settings.inputSwapMass > 0 then
    moves[#moves + 1] = { "inputSwapRewrite", settings.inputSwapMass, inputSwapRewrite }
  end
  if N > 0 and settings.lutChangeMass > 0 then
    moves[#moves + 1] = { "lutChangeRewrite", settings.lutChangeMass, lutChangeRewrite }
  end

  local massSum = 0
  for _, move in ipairs(moves) do
    massSum = massSum + move[2]
  end
  log.debug("massSum: " .. massSum)
  if #moves == 0 then
    error("createRewrite(): no legal rewrite has positive mass (N = " .. N .. ")")
  end

  local r = math.random() * massSum
  for i, move in ipairs(moves) do
    local label, mass, rewrite = move[1], move[2], move[3]
    -- The last move also absorbs any floating-point slack left in r after
    -- the subtractions, instead of failing with a "probability 0" error.
    if r < mass or i == #moves then
      log.info(label)
      -- Keep rnd strictly below 1 so the rewrites' ceil(rnd * count) stays in range.
      return rewrite(currentCircuit, math.min(r / mass, 1 - 1e-9))
    end
    r = r - mass
  end
end
--- Metropolis acceptance test for a proposed rewrite.
-- Accepts with probability min(1, exp(-beta * (rewriteCost - previousCost))),
-- so cheaper rewrites are always accepted and costlier ones with
-- exponentially decreasing probability.
-- @param rewriteCost cost of the proposed circuit
-- @param previousCost cost of the current circuit
-- @param settings table supplying `beta` (inverse temperature)
-- @return true to accept the rewrite
local function acceptRewrite(rewriteCost, previousCost, settings)
  log.trace("acceptRewrite")
  -- Equation 5: https://raw.githubusercontent.com/StanfordPL/stoke/develop/docs/papers/cacm16.pdf
  local acceptProbability = math.min(1.0, math.exp(-settings.beta*(rewriteCost-previousCost)))
  log.info("acceptProbability="..acceptProbability)
  -- math.random() is uniform on [0, 1), so `u < p` accepts with probability
  -- exactly p: always when p == 1, never when p == 0. (`p >= u` would accept
  -- a zero-probability move whenever u happened to be 0.)
  return math.random() < acceptProbability
end
--- Stochastic search for a cheaper circuit that still passes the tests.
-- Each iteration proposes a random rewrite of the current circuit
-- (createRewrite), scores it with cost(), and keeps it if acceptRewrite()
-- (Metropolis criterion, STOKE Eq. 5) says so. Every
-- settings.iterationsBetweenRestarts iterations the search restarts from
-- initialCircuit ("independent search").
-- @param initialCircuit starting circuit; also the restart point of every independent search
-- @param testSet test cases passed through cost() to errorCost()
-- @param validationSet validation cases, only evaluated by errorCost() once all tests pass
-- @param settings search settings (see ss.defaultSearchSettings); this function reads
--   totalIterations and iterationsBetweenRestarts directly, the rest is used by helpers
-- @return bestCircuit lowest-cost accepted circuit with correctness cost 0, or
--   initialCircuit if no such circuit was strictly cheaper than it
-- @return bestCost cost of bestCircuit
-- @return improved true iff bestCost < the initial circuit's cost
-- @return correctCircuits every circuit that set a new best, in the order found
-- Side effects: prints progress to stdout and writes Graphviz output under
-- out/ via cc.toGraphviz (always for each new best; additional dumps when
-- log.level is "debug" or "trace").
function ss.stochasticSearch(initialCircuit, testSet, validationSet, settings)
log.trace("Stochastic Search")
local currentCircuit = initialCircuit
local currentCost,currentCorrectCost = cost(initialCircuit, testSet, validationSet, settings)
print("Initial correctness cost: "..currentCorrectCost)
-- bestCost starts at the initial circuit's cost, so a correct circuit must be
-- strictly cheaper than the starting point to be recorded as a new best.
local bestCost = currentCost
-- Lowest cost seen for an incorrect accepted circuit in the current
-- independent search (reset on restart); only drives debug Graphviz dumps.
local bestIncorrectCost = currentCost
local initialCost = currentCost
local bestCircuit = currentCircuit
local correctCircuits = {}
-- NOTE(review): endCircuits collects the final circuit of each independent
-- search but is never read or returned — either return it or remove it.
local endCircuits = {}
for i=1,settings.totalIterations do
-- Start of an independent search (including i == 1, which recomputes the
-- cost already computed above): reset to the initial circuit.
if ((i-1) % settings.iterationsBetweenRestarts) == 0 then
currentCircuit = initialCircuit
currentCost = cost(initialCircuit, testSet, validationSet, settings)
bestIncorrectCost = currentCost
local independentSearchCount = (i-1) / settings.iterationsBetweenRestarts
print("------------------------------")
print(" Independent Search "..independentSearchCount)
print("------------------------------")
print("Cost of initial circuit: "..currentCost)
end
-- Propose and score one rewrite of the current circuit.
local rewriteCircuit = createRewrite(currentCircuit,settings)
log.trace("Rewritten")
local rewriteCost,rewriteCorrectnessCost = cost(rewriteCircuit, testSet, validationSet, settings)
--print("rewriteCost "..rewriteCost)
-- At debug/trace level, sanity-check both the current and the proposed circuit.
if log.level == "debug" or log.level == "trace" then
cc.nodeSanityCheck(currentCircuit)
print("========")
cc.nodeSanityCheck(rewriteCircuit)
end
if acceptRewrite(rewriteCost, currentCost, settings) then
log.info("Iteration "..i.." Rewrite accepted with cost: "..rewriteCost..", correctness cost: "..rewriteCorrectnessCost)
print("Iteration "..i.." Rewrite accepted with cost: "..rewriteCost..", correctness cost: "..rewriteCorrectnessCost)
currentCost = rewriteCost
currentCorrectCost = rewriteCorrectnessCost
currentCircuit = rewriteCircuit
-- Classify the accepted circuit: new correct best / correct and equal to
-- the best / cheaper than the best but incorrect (a correct one would
-- have taken the first branch).
if currentCorrectCost == 0 and currentCost < bestCost then
print("======================= NEW BEST CIRCUIT "..i.." =========================")
print("Cost: "..currentCost)
--if log.level == "debug" or log.level == "trace" then
cc.toGraphviz(currentCircuit, "out/correct"..(#correctCircuits + 1))
--end
correctCircuits[#correctCircuits + 1] = currentCircuit
bestCost = currentCost
bestCircuit = currentCircuit
elseif currentCorrectCost == 0 and currentCost == bestCost then
log.info("----- Equivalent best circuit: "..i)
elseif currentCost < bestCost then
log.info("----- Incorrect lower cost circuit: "..i)
if currentCost < bestIncorrectCost then
bestIncorrectCost = currentCost
if log.level == "debug" or log.level == "trace" then
cc.toGraphviz(currentCircuit, "out/incorrect_cost"..bestIncorrectCost)
end
end
end
else
log.info("Rewrite rejected")
end
-- Periodic progress heartbeat.
if i % 100000 == 0 then
print("Iteration: "..i)
end
-- Last iteration of an independent search (or of the whole run): record
-- the circuit the search ended on.
if ((i % settings.iterationsBetweenRestarts) == 0) or i == settings.totalIterations then
if log.level == "debug" or log.level == "trace" then
cc.toGraphviz(currentCircuit, "out/lastCircuit"..(i / settings.iterationsBetweenRestarts))
end
endCircuits[#endCircuits+1] = currentCircuit
end
end
return bestCircuit, bestCost, bestCost < initialCost, correctCircuits
end
-- Module export.
return ss