-
Notifications
You must be signed in to change notification settings - Fork 28
/
utils.py
207 lines (159 loc) · 5.97 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import os
import json
from groq import Groq
import openai
# API credentials. They are read from the environment so real secrets never
# need to be committed to source; the "###" placeholders remain as fallbacks,
# so behaviour is unchanged when the variables are not set.
groq_key = os.environ.get("GROQ_API_KEY", "###")
openai_key = os.environ.get("OPENAI_API_KEY", "###")
openai_org = os.environ.get("OPENAI_ORG_ID", "###")
# Module-level clients shared by get_response(); created once at import time.
groq_client = Groq(api_key=groq_key)
open_ai_client = openai.Client(api_key=openai_key, organization=openai_org)
def extract_json_from_end(text):
    """Extract and decode the last JSON object embedded in a model reply.

    First delegates to ``extract_json_from_end_backup``. If that fails with a
    ``ValueError``/``IndexError``, a heuristic fallback is used: take the text
    from the first ``{``, delete every backslash (replies often escape braces
    LaTeX-style), cut after the last ``}``, and walk left to its matching ``{``.

    Args:
        text: Raw model output.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: if no ``{`` / ``}`` is present, or the candidate span is not
            valid JSON.
    """
    try:
        return extract_json_from_end_backup(text)
    except (ValueError, IndexError):
        # The primary parser reports "not found" / "malformed" with these types.
        # Anything else (e.g. a programming error, KeyboardInterrupt) must not
        # be silently swallowed, so only these are caught.
        pass

    json_start = text.find("{")
    if json_start == -1:
        raise ValueError("No JSON object found in the text.")

    # Drop every backslash: replies often escape braces LaTeX-style (\{ ... \}).
    # NOTE: this also destroys legitimate JSON escapes such as \" inside strings.
    json_text = text[json_start:].replace("\\", "")

    # Discard anything after the last closing brace. Using rfind avoids the
    # negative-index wrap-around a manual backwards scan would hit when no "}"
    # exists.
    close = json_text.rfind("}")
    if close == -1:
        raise ValueError("No closing '}' found after the first '{'.")
    json_text = json_text[: close + 1]

    # Walk left from the last "}" to its matching "{" (braces inside string
    # literals are counted too; the scan is not string-aware).
    ind = close - 1
    depth = 1
    while depth > 0 and ind >= 0:
        if json_text[ind] == "}":
            depth += 1
        elif json_text[ind] == "{":
            depth -= 1
        ind -= 1
    json_text = json_text[ind + 1 :]

    try:
        return json.loads(json_text)
    except json.JSONDecodeError as e:
        raise ValueError(f"Failed to decode JSON: {e}") from e
def extract_json_from_end_backup(text):
    """Decode the last brace-balanced JSON object in *text*, tolerating // comments.

    If *text* contains a ```json fence, only the content of the first such fenced
    block is considered. The candidate is the span from the last ``}`` back to
    its matching ``{`` (brace counting is not string-aware). ``//`` line comments
    outside string literals are then removed from that span only, and the result
    is decoded with ``json.loads``.

    Args:
        text: Raw model output.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: if there is no ``}``, the last ``}`` has no matching ``{``,
            or the span is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError`` subclass).
    """
    if "```json" in text:
        text = text.split("```json")[1]
        text = text.split("```")[0]

    end = text.rfind("}")
    if end == -1:
        raise ValueError("No closing '}' found in the text.")

    # Walk left to the "{" matching the last "}"; bounded so it never wraps
    # around into negative indices.
    ind = end - 1
    depth = 1
    while depth > 0 and ind >= 0:
        if text[ind] == "}":
            depth += 1
        elif text[ind] == "{":
            depth -= 1
        ind -= 1
    if depth > 0:
        raise ValueError("Unbalanced braces: no '{' matches the last '}'.")

    # Slice FIRST, then strip comments: stripping the whole text would shift
    # the offsets computed above and leave prose comments in play.
    candidate = text[ind + 1 : end + 1]
    return json.loads(_strip_json_line_comments(candidate))


def _strip_json_line_comments(s):
    """Return *s* with ``//`` comments removed, ignoring ``//`` inside JSON strings."""
    out = []
    i = 0
    n = len(s)
    in_string = False
    while i < n:
        ch = s[i]
        if in_string:
            if ch == "\\" and i + 1 < n:
                # Copy the escape pair verbatim so an escaped quote cannot end
                # the string early.
                out.append(s[i : i + 2])
                i += 2
                continue
            if ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif s.startswith("//", i):
            newline = s.find("\n", i)
            if newline == -1:
                break  # the comment runs to the end of the text
            i = newline  # keep the newline itself as whitespace
            continue
        out.append(ch)
        i += 1
    return "".join(out)
def extract_list_from_end(text):
    """Decode the last bracket-balanced JSON array found in *text*.

    The scan starts at the final "]" and walks left, tracking bracket depth,
    until the matching "[" is reached; that slice is handed to json.loads.
    Brackets inside string literals are counted too (the scan is not
    string-aware).

    Raises:
        IndexError: if *text* has no "]" or its last "]" has no matching "[".
        json.JSONDecodeError: if the extracted slice is not valid JSON.
    """
    depth_change = {"]": 1, "[": -1}

    # Step left from the end until the last closing bracket.
    close_at = len(text) - 1
    while text[close_at] != "]":
        close_at -= 1
    head = text[: close_at + 1]

    # Walk left from just before it until the bracket depth returns to zero.
    cursor = close_at - 1
    depth = 1
    while depth > 0:
        depth += depth_change.get(head[cursor], 0)
        cursor -= 1

    return json.loads(head[cursor + 1 :])
def get_response(prompt, model="llama3-70b-8192"):
    """Send *prompt* as a single user message and return the reply content.

    The model id "llama3-70b-8192" is routed to the module-level Groq client;
    every other model id is sent to the module-level OpenAI client.

    Args:
        prompt: Text of the single user message.
        model: Model identifier passed through to the chat-completions call.

    Returns:
        ``message.content`` of the first choice in the completion.
    """
    client = groq_client if model == "llama3-70b-8192" else open_ai_client
    conversation = [{"role": "user", "content": prompt}]
    completion = client.chat.completions.create(messages=conversation, model=model)
    first_choice = completion.choices[0]
    return first_choice.message.content
def load_state(state_file):
    """Load and return the JSON document stored at *state_file*.

    The file is decoded as UTF-8 explicitly instead of relying on the
    platform's locale default, which differs across systems.
    """
    with open(state_file, "r", encoding="utf-8") as f:
        return json.load(f)
def save_state(state, dir):
    """Write *state* as JSON, indented by four spaces, to the file path *dir*.

    Despite its name, *dir* is the path of the output file, not a directory;
    the parameter name is kept for keyword-argument compatibility. An existing
    file at that path is overwritten.
    """
    destination = dir
    with open(destination, mode="w") as out_file:
        json.dump(obj=state, fp=out_file, indent=4)
def shape_string_to_list(shape_string):
    """Convert a textual shape such as "[N, M, K, 19]" into ['N', 'M', 'K', 19].

    A list argument is returned unchanged (the same object). For strings,
    surrounding whitespace and one enclosing "[...]" or "(...)" pair are
    removed; each comma-separated entry is stripped, entries made only of
    decimal digits become ints, other entries stay strings, and empty entries
    (from "[]" or a trailing comma) are dropped.
    """
    if isinstance(shape_string, list):
        return shape_string
    body = shape_string.strip()
    # Only remove a real wrapper; unbracketed input like "N, 3" must not lose
    # its first and last characters.
    if body and body[0] in "[(" and body[-1] in "])":
        body = body[1:-1]
    items = (item.strip() for item in body.split(","))
    # isdecimal (not isdigit): isdigit also accepts characters such as "²",
    # which int() rejects with ValueError.
    return [int(item) if item.isdecimal() else item for item in items if item]
def extract_equal_sign_closed(text):
    """Return the stripped text enclosed between "=====" markers.

    The opening marker is the first run of at least five "=" characters; the
    content starts right after that whole run and ends at the next "=====",
    or at the end of *text* if there is no closing marker.

    Raises:
        ValueError: if *text* contains no "=====" marker at all.
    """
    marker = "====="
    start = text.find(marker)
    if start == -1:
        raise ValueError("No '=====' marker found in the text.")
    start += len(marker)
    # Consume longer runs such as "======" so their tail is not mistaken for
    # the closing marker.
    while start < len(text) and text[start] == "=":
        start += 1
    end = text.find(marker, start)
    if end == -1:
        end = len(text)
    return text[start:end].strip()
class Logger:
    """Minimal append-only text logger bound to a single file path."""

    def __init__(self, file):
        # Path of the log file; it is reopened on every call, so no handle is held.
        self.file = file

    def log(self, text):
        """Append *text* plus a newline to the log file, creating it if missing.

        UTF-8 is used explicitly so text with non-ASCII characters (e.g. math
        symbols) cannot fail under a non-UTF-8 locale encoding.
        """
        with open(self.file, "a", encoding="utf-8") as f:
            f.write(text + "\n")

    def reset(self):
        """Truncate the log file to zero length, creating it if missing."""
        # Opening in "w" mode truncates; nothing needs to be written.
        with open(self.file, "w", encoding="utf-8"):
            pass
def create_state(parent_dir, run_dir):
    """Build the initial state for a run from a problem directory.

    Reads ``<parent_dir>/params.json`` (a mapping of name -> spec dict that has
    a "value" entry) and moves each "value" out into ``<run_dir>/data.json``.
    It then reads ``<parent_dir>/desc.txt``. All files use UTF-8 explicitly.

    Returns:
        ``{"description": <desc.txt text>, "parameters": <params minus "value">}``

    Raises:
        KeyError: if a parameter spec has no "value" entry.
    """
    with open(os.path.join(parent_dir, "params.json"), "r", encoding="utf-8") as f:
        params = json.load(f)

    # Split each spec: its concrete value goes to data.json, and the remaining
    # metadata stays in the returned state. Only the nested dicts are mutated,
    # so iterating the outer mapping is safe.
    data = {}
    for key, spec in params.items():
        data[key] = spec.pop("value")

    with open(os.path.join(run_dir, "data.json"), "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)

    with open(os.path.join(parent_dir, "desc.txt"), "r", encoding="utf-8") as f:
        desc = f.read()

    return {"description": desc, "parameters": params}
def get_labels(dir):
    """Load ``labels.json`` from the directory *dir* and return its parsed content.

    The parameter keeps the name ``dir`` (shadowing the builtin) for
    keyword-argument compatibility. The file is decoded as UTF-8 explicitly
    rather than with the platform's locale default.
    """
    with open(os.path.join(dir, "labels.json"), "r", encoding="utf-8") as f:
        return json.load(f)
if __name__ == "__main__":
    # Manual smoke test: a sample model reply whose JSON uses LaTeX-style
    # escaped braces (\{ ... \}) and whose outermost object is never closed.
    text = 'To maximize the number of successfully transmitted shows, we can introduce a new variable called "TotalTransmittedShows". This variable represents the total number of shows that are successfully transmitted.\n\nThe constraint can be formulated as follows:\n\n\\[\n\\text{{Maximize }} TotalTransmittedShows\n\\]\n\nTo model this constraint in the MILP formulation, we need to add the following to the variables list:\n\n\\{\n "TotalTransmittedShows": \\{\n "shape": [],\n "type": "integer",\n "definition": "The total number of shows transmitted"\n \\}\n\\}\n\nAnd the following auxiliary constraint:\n\n\\[\n\\forall i \\in \\text{{NumberOfShows}}, \\sum_{j=1}^{\\text{{NumberOfStations}}} \\text{{Transmitted}}[i][j] = \\text{{TotalTransmittedShows}}\n\\]\n\nThe complete output in the requested JSON format is:\n\n\\{\n "FORMULATION": "",\n "NEW VARIABLES": \\{\n "TotalTransmittedShows": \\{\n "shape": [],\n "type": "integer",\n "definition": "The total number of shows transmitted"\n \\}\n \\},\n "AUXILIARY CONSTRAINTS": [\n ""\n ]\n\\'
    # Print the parsed result so a manual run actually shows what was extracted.
    print(json.dumps(extract_json_from_end(text), indent=4))