-
Notifications
You must be signed in to change notification settings - Fork 0
/
binary_tree_grammar.py
104 lines (80 loc) · 3.09 KB
/
binary_tree_grammar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from numpy.random import uniform
import json
import sys
import os
class Tree:
    """A node of a binary tree; a leaf has no children.

    Attributes:
        depth: height of the subtree rooted here (0 for a leaf), as supplied
            by the caller.
        left: left child ``Tree`` or ``None``.
        right: right child ``Tree`` or ``None``.
        label: identifier of this node (``generate_tree`` uses integers).
    """

    def __init__(self, left=None, right=None, label=None, depth=None):
        # Assignment order kept so instance attribute order is unchanged.
        self.depth = depth
        self.left = left
        self.right = right
        self.label = label
class Pointer:
    """A mutable single-value box shared across recursive calls.

    Every recursive call receives the same ``Pointer`` object, so a value
    advanced in one call is seen by all later calls. ``create_grammar`` uses
    one as the node-label counter for ``generate_tree`` (labels drawn from it
    are unique) and another as the current terminal symbol for
    ``parse_tree``.
    """

    def __init__(self, val):
        # The boxed value; callers read and overwrite it directly.
        self.val = val
def increment_string(s):
    """Return the successor of *s* in a cyclic lowercase-letter sequence.

    Works like an odometer over 'a'..'z':
    - The last character is advanced.
    - A 'z' rolls over to 'a' and carries into the character on its left.
    - When the carry runs off the front (e.g. 'z' or 'zz'), the result wraps
      around to 'a' instead of growing longer.

    So single-letter symbols cycle 'a' -> 'b' -> ... -> 'z' -> 'a'. Any
    character other than 'z' is advanced to the next code point.

    Args:
        s: current symbol; may be empty.

    Returns:
        The next symbol; 'a' for an empty input or on overflow.

    Examples:
        >>> increment_string('a')
        'b'
        >>> increment_string('z')
        'a'
        >>> increment_string('az')
        'ba'
    """
    if not s:
        return 'a'
    chars = list(s)
    for i in reversed(range(len(chars))):
        if chars[i] != 'z':
            chars[i] = chr(ord(chars[i]) + 1)
            return ''.join(chars)
        # 'z' rolls over; keep carrying into the next position to the left.
        chars[i] = 'a'
    # The carry overflowed past the first character: wrap around to 'a'.
    return 'a'
def generate_tree(depth, prob, index, root = True):
    """Randomly generate a binary tree; a forced (``root``) tree has height exactly *depth*.

    Nodes are labelled in pre-order with consecutive integers drawn from the
    shared counter ``index``.

    The right-hand subtree of a forced node tends to be deeper on average.
    That is because it is forced only when the left subtree fell short.
    Per the original author, this does not matter because the generated
    production rules are permutation invariant.

    Args:
        depth: required height of the tree (0 yields a single leaf).
        prob: probability in [0, 1] that an unforced internal node branches
            into two children rather than stopping as a leaf.
        index: object with an integer ``val`` attribute (e.g. ``Pointer``).
            It supplies the next label and is incremented once per node.
        root: when True the subtree is *forced* to reach exactly ``depth``.
            When False it stops early with probability ``1 - prob`` at every
            level, so its height is at most ``depth``.

    Returns:
        A ``Tree`` whose ``depth`` attribute is the height of the generated
        subtree (equal to ``depth`` when ``root`` is True).

    Raises:
        ValueError: if ``depth`` is negative. Previously, a negative depth
            with ``root=True`` recursed until RecursionError.
    """
    if depth < 0:
        raise ValueError(f"depth must be non-negative, got {depth}")
    curr_index = int(index.val)
    index.val += 1
    if depth == 0:
        return Tree(label=curr_index, depth=0)
    if not root:
        # Unforced node: branch with probability ``prob``, otherwise stop as
        # a leaf. ``>=`` (not ``>``) makes prob == 1 branch unconditionally,
        # even if uniform() returns exactly 0.0; prob == 0 still never
        # branches because uniform() draws from [0, 1).
        if uniform() >= 1 - prob:
            left = generate_tree(depth - 1, prob, index, False)
            right = generate_tree(depth - 1, prob, index, False)
            return Tree(left, right, curr_index, max(left.depth, right.depth) + 1)
        return Tree(label=curr_index, depth=0)
    # Forced node: grow the left subtree freely. Force the right subtree to
    # full height only if the left one fell short, so this node reaches
    # exactly ``depth``.
    left = generate_tree(depth - 1, prob, index, False)
    right = generate_tree(depth - 1, prob, index, left.depth != depth - 1)
    return Tree(left, right, curr_index, max(left.depth, right.depth) + 1)
def parse_tree(tree, rules, terminator, symbols):
    """Append one production rule per node of *tree* to *rules*, in pre-order.

    For every node, the current terminal symbol ``str(terminator.val)`` is
    taken and ``terminator.val`` is advanced with ``increment_string``. The
    symbol is appended to *symbols*, and a rule dict is appended to *rules*:

    - leaf:          ``{"From": label, "To": [symbol]}``
    - internal node: ``{"From": label, "To": [[left.label, right.label], symbol]}``

    Assumes every node with a left child also has a right child, as
    ``generate_tree`` builds them.

    Args:
        tree: root ``Tree`` node.
        rules: list that receives the rule dicts (mutated).
        terminator: object with a ``val`` attribute holding the next terminal
            symbol (e.g. ``Pointer('a')``); advanced once per node (mutated).
        symbols: list that receives each node's terminal symbol (mutated).

    Returns:
        None.
    """
    pending = [tree]
    while pending:
        node = pending.pop()
        symbol = str(terminator.val)
        terminator.val = increment_string(terminator.val)
        if node.left is None:
            alternatives = [symbol]
        else:
            alternatives = [[node.left.label, node.right.label], symbol]
        symbols.append(symbol)
        rules.append({"From": node.label, "To": alternatives})
        if node.left is not None:
            # Push right before left so the left subtree is fully visited
            # first, matching a recursive pre-order traversal.
            pending.append(node.right)
            pending.append(node.left)
def create_grammar(depth = 3, prob = 1):
    """Build a random binary-tree grammar as a JSON-serialisable dict.

    A tree of height exactly *depth* is generated with ``generate_tree``
    (node labels count up from 0). It is then converted to production rules
    with ``parse_tree``, using terminal symbols starting at 'a'.

    Args:
        depth: exact height of the underlying tree.
        prob: branching probability for unforced nodes (see ``generate_tree``).

    Returns:
        A dict with keys:
        - "start_symbol": label of the root node.
        - "rules": list of rule dicts in pre-order.
        - "symbols": the distinct terminal symbols, sorted. Previously this
          was ``list(set(...))``, whose order depends on string hash
          randomisation and so changed between runs.
    """
    tree = generate_tree(depth, prob, Pointer(0), True)
    rules = []
    symbols = []
    parse_tree(tree, rules, Pointer('a'), symbols)
    return {"start_symbol": tree.label, "rules": rules, "symbols": sorted(set(symbols))}
if __name__ == "__main__":
    # Generate a height-8 grammar (branching probability 0.75) and write it as
    # JSON to grammars/binary_tree/test.json, relative to the current
    # working directory.
    grammar = create_grammar(8, 0.75)
    out_dir = os.path.join("grammars", "binary_tree")
    # open() alone raises FileNotFoundError when the directory is missing.
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "test.json"), "w", encoding="utf-8") as f:
        json.dump(grammar, f, indent=4)