-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add DODO-GPT #343
base: master
Are you sure you want to change the base?
Add DODO-GPT #343
Changes from 1 commit
5d9f41b
49d1364
597161d
9062ee0
5560e81
b56a00c
1ad2b75
c8c7601
3566a25
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,261 @@ | ||
# DODO-GPT is brought to you by Dave the Dodo of Picturing Quantum Processes fame | ||
|
||
import pyzx as zx | ||
import openai | ||
import sounddevice as sd | ||
import os | ||
import numpy as np | ||
import base64 | ||
import requests | ||
from pydub import AudioSegment | ||
from .utils import GET_MODULE_PATH | ||
from enum import IntEnum | ||
|
||
API_KEY = 'no-key' | ||
|
||
# Need to add error handling still! (e.g. if invalid key, or if no connection to chat-gpt server, etc.) | ||
# Maybe add a config (i.e. to set the api_key, the gpt model, the OpenAI Whisper voice, etc.) | ||
# Maybe also record your full transcripts with DODO-GPT (and have the option of whether to save it to file when you close zxlive)
|
||
# TEMP - this should just be calling zx.utils.VertexType instead (not sure why that doesn't work though?): VertexType(1).name
class VertexType(IntEnum):
    """Type of a vertex in the graph.

    NOTE(review): values are assumed to mirror pyzx.utils.VertexType
    (see the TEMP comment above) — confirm they stay in sync with pyzx.
    """
    BOUNDARY = 0   # diagram input/output endpoint, not a real spider
    Z = 1          # Z-spider
    X = 2          # X-spider
    H_BOX = 3      # Hadamard box
    W_INPUT = 4    # W-node input
    W_OUTPUT = 5   # W-node output
    Z_BOX = 6      # Z-box
||
# TEMP - this should just be calling zx.utils.EdgeType instead (not sure why that doesn't work though?): EdgeType(1).name
class EdgeType(IntEnum):
    """Type of an edge in the graph.

    NOTE(review): values are assumed to mirror pyzx.utils.EdgeType
    (see the TEMP comment above) — confirm they stay in sync with pyzx.
    """
    SIMPLE = 1     # plain wire
    HADAMARD = 2   # wire carrying a Hadamard
    W_IO = 3       # wire attached to a W node
||
def get_local_api_key():
    """Get the API key from key.txt file.

    Reads <module>/user/key.txt and stores its contents in the module-level
    API_KEY, which the OpenAI helper functions below read on their next call.
    """
    global API_KEY
    # `with` guarantees the file handle is closed even if read() raises, and
    # .strip() drops the trailing newline most editors leave in key files —
    # without it the key sent to OpenAI is invalid.
    with open(GET_MODULE_PATH() + "/user/key.txt", "r") as f:
        API_KEY = f.read().strip()
||
def query_chatgpt(prompt,model="gpt-3.5-turbo"):
    """Send a prompt to Chat-GPT and return its response."""
    # One-shot completion: a single user message, no conversation history.
    gpt_client = openai.OpenAI(api_key=API_KEY)
    completion = gpt_client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=model,  # e.g. "gpt-3.5-turbo" or "gpt-4o"
    )
    return completion.choices[0].message.content
||
def prep_hint(g):
    """Build the generic 'ask for a hint' prompt for DODO-GPT about graph *g*."""
    hint_request = "Please advise me as to what ONE simplification step I should take to help immediately simplify this ZX-diagram. Please be specific to this case and not give general simplification tips. Please also keep your answer simple and do not write any diagrams or data in return."
    # Prefix the structured description of the current diagram, then the fixed ask.
    return describe_graph(g) + hint_request
||
def describe_graph(g):
    """Prime DODO-GPT before making a query. Returns a string for describing to DODO-GPT the current ZX-diagram.

    Follows the format...

    Consider a ZX-diagram defined by the following spiders:

    label, type, phase
    0, Z, 0.25
    1, Z, 0.5
    2, X, 0.5

    with the following edges:

    source,target,type
    0, 1, SIMPLE
    1, 2, HADAMARD
    """
    # Local value->name maps, renamed so they no longer shadow the
    # module-level VertexType/EdgeType enums, and extended to cover every
    # type those enums define — previously a graph containing e.g. an H-box
    # raised a KeyError here.
    vertex_names = {0: 'BOUNDARY', 1: 'Z', 2: 'X', 3: 'H_BOX',
                    4: 'W_INPUT', 5: 'W_OUTPUT', 6: 'Z_BOX'}
    edge_names = {1: 'SIMPLE', 2: 'HADAMARD', 3: 'W_IO'}

    strPrime = "\nConsider a ZX-diagram defined by the following spiders:\n\nlabel, type, phase\n"
    for v in g.vertices():
        strPrime += str(v) + ', ' + str(vertex_names[g.type(v)]) + ', ' + str(g.phase(v)) + '\n'

    strPrime += "\nwith the following edges:\n\nsource,target,type\n"
    for e in g.edges():
        strPrime += str(e[0]) + ', ' + str(e[1]) + ', ' + str(edge_names[g.edge_type(e)]) + '\n'
    strPrime += '\n'

    return strPrime
||
def text_to_speech(text):
    """Generates an mp3 file reading the given text (via OpenAI Whisper)."""
    speech_client = openai.OpenAI(api_key=API_KEY)
    audio = speech_client.audio.speech.create(
        model="tts-1",
        voice="nova",
        input=text,
    )
    mp3_path = GET_MODULE_PATH() + "/temp/Dodo_Dave_latest.mp3"
    audio.stream_to_file(mp3_path)
    # TEMP/TODO - THIS SHOULD BE USING A PROPER IN-APP AUDIO PLAYER RATHER THAN OS
    os.system(mp3_path)
||
def record_audio(duration=5, sample_rate=44100):
    """Record *duration* seconds of 16-bit stereo audio from the default input device."""
    frame_count = int(duration * sample_rate)
    clip = sd.rec(frame_count, samplerate=sample_rate, channels=2, dtype='int16')
    sd.wait()  # sd.rec is asynchronous; block until the recording completes
    return clip
||
def save_as_mp3(audio_data, sample_rate=44100):
    """Write raw recorded samples to temp/user_query_latest.mp3 and return the path."""
    out_path = GET_MODULE_PATH() + "/temp/user_query_latest.mp3"
    # pydub consumes raw PCM bytes: 16-bit (sample_width=2), stereo, as recorded.
    raw_bytes = np.array(audio_data).tobytes()
    segment = AudioSegment(
        data=raw_bytes,
        sample_width=2,
        frame_rate=sample_rate,
        channels=2,
    )
    segment.export(out_path, format='mp3')
    return out_path
||
def transcribe_audio(file_path):
    """Transcribe the audio file at *file_path* to English text via OpenAI Whisper."""
    whisper_client = openai.OpenAI(api_key=API_KEY)
    with open(file_path, "rb") as audio_file:
        result = whisper_client.audio.transcriptions.create(
            model="whisper-1",
            file=audio_file,
            language='en',
        )
    return result.text
||
def speech_to_text():
    """Record 5 seconds from the microphone and return the transcribed text."""
    rate_hz = 44100  # sample rate in Hz
    seconds = 5      # recording length in seconds
    mp3_path = save_as_mp3(record_audio(seconds, rate_hz), rate_hz)
    return transcribe_audio(mp3_path)
||
def encode_image(image_path):
    """Read the file at *image_path* and return its contents base64-encoded as utf-8 text."""
    with open(image_path, "rb") as image_file:
        raw = image_file.read()
    return base64.b64encode(raw).decode('utf-8')
||
def get_image_prompt():
    """Returns the prompt that encourages DODO-GPT to describe the given image like a ZX-diagram."""
    prompt = """
Please convert this image into a ZX-diagram.

Then please provide a csv that lists of the spiders of this ZX-diagram, given the column headers:
index,type,phase,x-pos,y-pos

The type here should be given as either 'Z' or 'X' (ignore boundary spiders). The indexing should start from 0. And the phases should be written in terms of pi. x-pos and y-pos should respectively refer to their horizontal and vertical positions in the image, normalized from 0,0 (top-left) to 1,1 (bottom-right).

Then please provide a csv that lists the edges of this ZX-diagram, given the column headers:
source,target,type

The type here should be given as 1 for a normal (i.e. black) edge and 2 for a Hadamard (i.e. blue) edge, and the sources and targets should refer to the indices of the relevant spiders. Be sure to only include direct edges connecting two spiders.

Please ensure the csv's are expressed with comma separators and not in a table format.
"""
    return prompt
    # An earlier variant of the prompt also appended a "HINT" section:
    #After that, under a clearly marked heading "HINT", please advise me as to what ONE simplification step I should take to help immediately simplify this ZX-diagram. Please be specific to this case and not give general simplification tips.
||
def image_to_text(image_path):
    """Takes a ZX-diagram-like image and returns DODO-GPT's structured description of it."""
    prompt_text = get_image_prompt()
    base64_image = encode_image(image_path)  # image goes inline as a base64 data URL

    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}"
    }

    request_body = {
        "model": "gpt-4o-mini",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_text},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}
                    }
                ]
            }
        ],
        "max_tokens": 300
    }

    # Raw REST call (rather than the openai client) so the image payload can be
    # posted exactly as above.
    reply = requests.post("https://api.openai.com/v1/chat/completions",
                          headers=request_headers, json=request_body)
    return reply.json()['choices'][0]['message']['content']
||
def response_to_zx(strResponse):
    """Parse DODO-GPT's csv-style answer (format set by get_image_prompt) into a pyzx graph.

    Expects a spider csv headed 'index,type,phase' and an edge csv headed
    'source,target,type', each terminated by a ``` code fence; str.index
    raises ValueError if any marker is missing.
    """
    scale = 2.5  # spread the normalized 0..1 image coordinates over qubit/row space

    # Slice out the two csv payloads: drop each header line, stop before the
    # closing code fence. (Removed the old no-op `strResponse = strResponse`.)
    strResponse = strResponse[strResponse.index('index,type,phase'):]
    strResponse = strResponse[strResponse.index('\n') + 1:]
    str_csv_verts = strResponse[:strResponse.index('```') - 1]
    strResponse = strResponse[strResponse.index('source,target,type'):]
    strResponse = strResponse[strResponse.index('\n') + 1:]
    str_csv_edges = strResponse[:strResponse.index('```') - 1]

    g = zx.Graph()

    for line in str_csv_verts.split('\n'):
        line = line.strip()
        if not line:
            continue  # tolerate blank lines the model sometimes inserts
        idx, ty, ph, x, y = (field.strip() for field in line.split(','))
        # NOTE(review): ph is forwarded as a string (the prompt asks for phases
        # "in terms of pi", e.g. "pi/2") — confirm pyzx add_vertex accepts that.
        g.add_vertex(qubit=float(y) * scale, row=float(x) * scale,
                     ty=VertexType[ty], phase=ph)

    for line in str_csv_edges.split('\n'):
        line = line.strip()
        if not line:
            continue
        source, target, ty = (field.strip() for field in line.split(','))
        g.add_edge((int(source), int(target)), int(ty))

    return g
||
def action_dodo_hint(active_graph) -> None:
    """Queries DODO-GPT for a hint as to what simplification step should be taken next."""
    hint_prompt = prep_hint(active_graph)
    answer = query_chatgpt(hint_prompt)
    text_to_speech(answer)  # Dave the Dodo reads the hint aloud
||
def action_dodo_query(active_graph) -> None:
    """Records the user's voice (plus the current ZX-diagram) and prompts DODO-GPT for a response.

    Parameters
    ----------
    active_graph : the ZX-diagram currently shown in the editor.
    """
    # Whether to pass information about the current ZX-diagram in with the
    # DODO-GPT query. (This flag was previously declared but never consulted;
    # it now actually gates the description, with the original default.)
    doIncludeGraph = True
    strPrime = describe_graph(active_graph) if doIncludeGraph else ""
    userQuery = speech_to_text()
    dodoResponse = query_chatgpt(strPrime + userQuery)
    text_to_speech(dodoResponse)
||
def action_dodo_image_to_zx(path):
    """Queries DODO-GPT to generate a ZX-diagram from an image.

    Returns the pyzx graph built from DODO-GPT's structured description.
    (The previous `-> None` annotation was wrong: the graph is returned
    for the caller to display.)
    """
    strResponse = image_to_text(path)
    new_graph = response_to_zx(strResponse)
    return new_graph
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
from typing import Iterator | ||
|
||
from PySide6.QtCore import Signal, QSettings | ||
from PySide6.QtGui import QAction | ||
from PySide6.QtGui import (QAction, QIcon) | ||
from PySide6.QtWidgets import (QToolButton) | ||
from pyzx import EdgeType, VertexType, sqasm | ||
from pyzx.circuit.qasmparser import QASMParser | ||
|
@@ -18,6 +18,7 @@ | |
from .graphscene import EditGraphScene | ||
from .graphview import GraphView | ||
from .settings_dialog import input_circuit_formats | ||
from .dodo import action_dodo_hint, action_dodo_query | ||
|
||
|
||
class GraphEditPanel(EditorBasePanel): | ||
|
@@ -60,6 +61,18 @@ def _toolbar_sections(self) -> Iterator[ToolbarSection]: | |
self.start_derivation.setText("Start Derivation") | ||
self.start_derivation.clicked.connect(self._start_derivation) | ||
yield ToolbarSection(self.start_derivation) | ||
|
||
self.dodo_hint = QToolButton(self) | ||
self.dodo_hint.setIcon(QIcon(get_data("icons/dodo.png"))) | ||
self.dodo_hint.setToolTip("Dodo Hint") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Although this is funny, maybe make it more descriptive? "Ask DodoGPT for suggestions on how to rewrite your diagram" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure! |
||
self.dodo_hint.clicked.connect(self._dodo_hint) | ||
|
||
self.dodo_query = QToolButton(self) | ||
self.dodo_query.setIcon(QIcon(get_data("icons/mic.svg"))) | ||
self.dodo_query.setToolTip("Dodo Query") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps make it clearer that this will use input from your mic. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep, I can add a little note there |
||
self.dodo_query.clicked.connect(self._dodo_query) | ||
yield ToolbarSection(self.dodo_hint, self.dodo_query) | ||
|
||
|
||
def _start_derivation(self) -> None: | ||
if not self.graph_scene.g.is_well_formed(): | ||
|
@@ -107,3 +120,9 @@ def _input_circuit(self) -> None: | |
cmd = UpdateGraph(self.graph_view, new_g) | ||
self.undo_stack.push(cmd) | ||
self.graph_scene.select_vertices(new_verts) | ||
|
||
def _dodo_hint(self) -> None:
    """Ask DODO-GPT (via text-to-speech) for a hint about the next simplification step on the current graph."""
    action_dodo_hint(self.graph_scene.g)
||
def _dodo_query(self) -> None:
    """Record a spoken question from the microphone and send it, with the current graph, to DODO-GPT."""
    action_dodo_query(self.graph_scene.g)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be better to have the API key be settable in the settings of ZXLive.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That was my plan! This is just a quick temporary solution so that I can push the changes. I guess we could also save API keys locally if we hash them first, so that way you won't have to re-enter it in every new instance of the program?