-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathssmt.py
86 lines (67 loc) · 1.97 KB
/
ssmt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#
# Stochastic SMT
#
from searchtreesample import *
# Module-wide model store shared by every helper below:
#   addr_idx    - counter appended to names so each generated address is unique
#   constraints - formulas observed so far (conjoined when sampling)
#   variables   - random variables created so far
addr_idx = 0
constraints, variables = [], []
def reset():
    """Reset the global store to its initial state.

    Empties the module-level ``constraints`` and ``variables`` lists and sets
    the address counter ``addr_idx`` back to 0.
    """
    # Without a ``global`` declaration, assigning to addr_idx would only
    # create a local variable and leave the module counter untouched.
    global addr_idx
    # Clear the lists in place instead of rebinding them, so any other holder
    # of a reference to these list objects (e.g. a module that imported them
    # by name) also observes the reset. ``del lst[:]`` works on Python 2 and 3.
    del constraints[:]
    del variables[:]
    addr_idx = 0
def gen_addr(name=None):
    """Generate a fresh address of the form ``<name>_<n>``.

    NAME is an optional prefix that defaults to ``"v"``. The suffix ``n`` is
    the current value of the module-level counter ``addr_idx``, which is
    incremented on every call so successive calls yield distinct addresses.
    """
    global addr_idx
    if name is None:
        name = "v"
    res = name + "_" + str(addr_idx)
    addr_idx += 1
    return res
def add_var(vartype, name=None):
    """Create, register, and return a new random variable.

    VARTYPE is a z3py variable constructor (such as ``Int`` or ``Bool``). It is
    called with a fresh address derived from NAME, and the resulting variable
    is appended to the global ``variables`` list before being returned.
    """
    address = gen_addr(name)
    new_var = vartype(address)
    variables.append(new_var)
    return new_var
def observe(f):
    """Condition the model on formula F by recording it as a global constraint."""
    constraints.extend((f,))
def sample(num_samples, proc):
    """Draw NUM_SAMPLES samples from the model described by PROC.

    PROC is a zero-argument callable that builds the model by creating random
    variables and recording constraints through this module's helpers.
    ``reset()`` is called first, then PROC runs. The conjunction of all
    recorded constraints (or ``True`` when none were recorded) is passed to
    ``search_tree_sample`` together with the recorded variables, once per
    requested sample.

    Returns a list containing one ``search_tree_sample`` result per sample.
    """
    reset()
    proc()
    final_formula = And(*constraints) if constraints else True
    # Build the list eagerly. On Python 3 ``map`` is lazy, so samples would
    # only be drawn when iterated, possibly after the global store has been
    # reset by another model. A list comprehension behaves the same on 2 and 3.
    # NOTE(review): the literal 2 is an argument of search_tree_sample whose
    # meaning is defined in searchtreesample; confirm before changing it.
    return [search_tree_sample(variables, final_formula, 2)
            for _ in range(num_samples)]
# Tests: model-building helpers (randint, flip, add, cond) and a demo model
def randint(a, b, name=None):
    """Create an Int random variable constrained to the closed range [A, B]."""
    var = add_var(Int, name)
    lower_ok = var >= a
    upper_ok = var <= b
    observe(And(lower_ok, upper_ok))
    return var
def flip(name=None):
    """Create and return a fresh, unconstrained Bool random variable."""
    return add_var(Bool, name)
def add(x, y):
    """Return a fresh Int variable constrained to equal X + Y."""
    total = add_var(Int)
    observe(total == x + y)
    return total
def cond(c, t, e, restype=Int):
    """Model the conditional expression ``T if C else E``.

    Creates a Bool variable tied to the condition C and a result variable of
    type RESTYPE. The result is constrained to equal T when the condition
    holds and E otherwise. Returns the result variable.
    """
    guard = add_var(Bool)
    result = add_var(restype)
    constraints.extend([
        guard == c,
        Implies(guard, result == t),
        Implies(Not(guard), result == e),
    ])
    return result
def test(num_samples=100):
    """Demo: draw NUM_SAMPLES samples (default 100) from a small model and print them.

    The model creates a Bool ``b`` and two Ints ``x1``, ``x2`` in [0, 10].
    It observes that ``x1 + x2 == 6`` when ``b`` holds and ``x1 + x2 == 10``
    otherwise.
    """
    def model():
        b = flip(name="b")
        x1 = randint(0, 10, name="x1")
        x2 = randint(0, 10, name="x2")
        x3 = add(x1, x2)
        observe(Implies(b, x3 == 6))
        observe(Implies(Not(b), x3 == 10))

    samples = sample(num_samples, model)
    for s in samples:
        # The call form prints the same value as the Python 2 ``print s``
        # statement, but also parses under Python 3.
        print(s)
# Run the demo at module level: it executes both when this file is run as a
# script and whenever it is imported.
# NOTE(review): consider guarding with `if __name__ == "__main__":` so that
# importing this module does not draw samples as a side effect.
test()