-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgen.py
82 lines (65 loc) · 2.06 KB
/
gen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import io
import sys
from functools import reduce
from textgenrnn import textgenrnn
from random import choice, shuffle
# Base street names, one per line; read by build_basenames() so generate()
# can reject outputs that exactly copy an entry from the training corpus.
STREETS_FILE = "corpus/streets.txt"
# Road-type frequency table, one "<road type>,<count>" entry per line; read by
# build_roadtypes(). Both paths are relative to the current working directory.
ROADTYPES_FILE = "corpus/roadtypes.txt"
def build_roadtypes(path=None):
    """
    Return a list of road types expanded by their frequency.

    Each non-blank line of the frequency file must look like
    ``<road type>,<count>``. Every road type is repeated ``count`` times, so
    picking a uniformly random element (e.g. with ``random.choice``) selects
    road type T with probability proportional to T's count in the file.

    :param path: frequency file to read; defaults to ROADTYPES_FILE.
    :returns: shuffled list of road-type strings (surrounding whitespace
        stripped).
    :raises ValueError: if a non-blank line has no comma or a non-integer
        count.
    """
    if path is None:
        path = ROADTYPES_FILE
    result = []
    with open(path, encoding="utf-8") as file:
        for line in file:
            if not line.strip():
                # tolerate blank lines, e.g. a trailing newline at EOF
                continue
            # rsplit on the last comma so a road type containing a comma
            # still parses; the count is always the final field
            name, freq = line.rsplit(",", 1)
            result.extend([name.strip()] * int(freq))
    # choice() is already uniform over the list, so shuffling is not needed
    # for correctness; kept so callers that iterate the list see mixed order
    shuffle(result)
    return result
def build_basenames(path=None):
    """
    Return the known base street names as a lookup table.

    Each line of the streets file is stripped of surrounding whitespace and
    used as a key mapped to True. A dict (not a set) is returned for backward
    compatibility, so both ``name in table`` and ``table.get(name, False)``
    work.

    :param path: streets file to read; defaults to STREETS_FILE.
    :returns: dict mapping each stripped line to True.
    """
    if path is None:
        path = STREETS_FILE
    with open(path, encoding="utf-8") as file:
        return dict.fromkeys((line.strip() for line in file), True)
def get_roadtype():
    """Pick one road type at random from the module-level ``roadtypes`` pool."""
    pool = roadtypes
    return choice(pool)
def add_roadtype(name):
    """Return *name* followed by one space and a randomly chosen road type."""
    suffix = get_roadtype()
    return "{0} {1}".format(name, suffix)
def generate(temp, max_attempts=None):
    """
    Generate names with the module-level ``textgen`` model, rejecting any
    output that exactly matches a known base street name so the script does
    not simply regurgitate its input data.

    :param temp: sampling temperature passed to ``textgen.generate``.
    :param max_attempts: optional cap on sampling rounds; None (the default)
        keeps sampling until a name not in ``basenames`` appears.
    :returns: the list returned by
        ``textgen.generate(temperature=temp, return_as_list=True)``.
    :raises RuntimeError: if ``max_attempts`` rounds all produced known names.
    """
    attempts = 0
    while max_attempts is None or attempts < max_attempts:
        attempts += 1
        result = textgen.generate(temperature=temp, return_as_list=True)
        # the model returns a list; join it into one string for the lookup
        candidate = ' '.join(result)
        if candidate not in basenames:
            return result
    raise RuntimeError(
        "no novel name generated in {} attempts".format(max_attempts))
"""
MAIN
"""
roadtypes = build_roadtypes()
basenames = build_basenames()
# force print into utf8 mode?
sys.stdout = io.TextIOWrapper(sys.stdout.detach(), 'utf8', 'replace')
textgen = textgenrnn('textgenrnn_weights.hdf5')
print("\n")
for t in [0.3, 0.5, 1, 1.25]:
print("Temperature {}".format(t))
print("--------------------------------------------------------------------")
for _ in range(0, 5):
list = generate(t)
for l in list:
print(add_roadtype(l))
print("")