Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DODO-GPT #343

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ dependencies = [
"numpy",
"shapely",
"pyperclip",
"imageio"
"imageio",
"openai",
"pydub",
"sounddevice",
"ffmpeg-python"
]

[project.optional-dependencies]
Expand Down
6 changes: 5 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,8 @@ sphinx_autodoc_typehints>=1.10
sphinx_rtd_theme>=0.4
sphinxcontrib-svg2pdfconverter>=1.2.2
myst-parser>=3.0.0
imageio
imageio
openai
pydub
sounddevice
ffmpeg-python
20 changes: 18 additions & 2 deletions zxlive/dialogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@
class FileFormat(Enum):
"""Supported formats for importing/exporting diagrams."""

All = "zxg *.json *.qasm *.tikz *.zxp *.zxr *.gif", "All Supported Formats"
All = "zxg *.json *.qasm *.tikz *.zxp *.zxr *.gif *.png *.jpg", "All Supported Formats"
QGraph = "zxg", "QGraph" # "file extension", "format name"
QASM = "qasm", "QASM"
TikZ = "tikz", "TikZ"
Json = "json", "JSON"
ZXProof = "zxp", "ZXProof"
ZXRule = "zxr", "ZXRule"
Gif = "gif", "Gif"
PNG = "png", "PNG"
JPEG = "jpg", "JPEG"
_value_: str

def __new__(cls, *args, **kwds): # type: ignore
Expand Down Expand Up @@ -77,7 +79,6 @@ class ImportRuleOutput:
file_path: str
r: CustomRule


def show_error_msg(title: str, description: Optional[str] = None, parent: Optional[QWidget] = None) -> None:
"""Displays an error message box."""
msg = QMessageBox(parent) #Set the parent of the QMessageBox
Expand Down Expand Up @@ -258,6 +259,21 @@ def export_gif_dialog(parent: QWidget) -> Optional[str]:
if file_path_and_format is None or not file_path_and_format[0]:
return None
return file_path_and_format[0]

def import_image_dialog(parent: QWidget) -> Optional[ImportGraphOutput | ImportProofOutput | ImportRuleOutput]:
"""Shows a dialog to import a diagram from an image on disk.

Generates and returns the imported graph or `None` if the import failed."""
file_path, selected_filter = QFileDialog.getOpenFileName(
parent=parent,
caption="Select Image",
filter=FileFormat.PNG.filter
)
if selected_filter == "":
# This happens if the user clicks on cancel
return None

return file_path

def get_lemma_name_and_description(parent: MainWindow) -> tuple[Optional[str], Optional[str]]:
dialog = QDialog(parent)
Expand Down
264 changes: 264 additions & 0 deletions zxlive/dodo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
# DODO-GPT is brought to you by Dave the Dodo of Picturing Quantum Processes fame

import pyzx as zx
import openai
try:
import sounddevice as sd
except:
print('Unable to import sounddevice. DODO Query will be unavailable.') #TEMP
import os
import numpy as np
import base64
import requests
from pydub import AudioSegment
from .utils import GET_MODULE_PATH
from enum import IntEnum

API_KEY = 'no-key'

# Need to add error handling still! (e.g. if invalid key, or if no connection to chat-gpt server, etc.)
# Maybe add a config (i.e. to set the api_key, the gpt model, the OpenAI Whisper voice, etc.)
# Maybe also record your full transripts with DODO-GPT (and have the option of whether to save it to file when you close zxlive)

#TEMP... (# TEMP - this should just be calling zx.utils.VertexType instead (not sure why that doesn't work though?): VertexType(1).name)
class VertexType(IntEnum):
"""Type of a vertex in the graph."""
BOUNDARY = 0
Z = 1
X = 2
H_BOX = 3
W_INPUT = 4
W_OUTPUT = 5
Z_BOX = 6

#TEMP... (# TEMP - this should just be calling zx.utils.VertexType instead (not sure why that doesn't work though?): VertexType(1).name)
class EdgeType(IntEnum):
"""Type of an edge in the graph."""
SIMPLE = 1
HADAMARD = 2
W_IO = 3

def get_image_prompt():
"""Returns the prompt that encourages DODO-GPT to describe the given image like a ZX-diagram."""

return """
Please convert this image into a ZX-diagram.

Then please provide a csv that lists of the spiders of this ZX-diagram, given the column headers:
index,type,phase,x-pos,y-pos

The type here should be given as either 'Z' or 'X' (ignore boundary spiders). The indexing should start from 0. And the phases should be written in terms of pi. x-pos and y-pos should respectively refer to their horizontal and vertical positions in the image, normalized from 0,0 (top-left) to 1,1 (bottom-right).

Then please provide a csv that lists the edges of this ZX-diagram, given the column headers:
source,target,type

The type here should be given as 1 for a normal (i.e. black) edge and 2 for a Hadamard (i.e. blue) edge, and the sources and targets should refer to the indices of the relevant spiders. Be sure to only include direct edges connecting two spiders.

Please ensure the csv's are expressed with comma separators and not in a table format.
"""
#After that, under a clearly marked heading "HINT", please advise me as to what ONE simplification step I should take to help immediately simplify this ZX-diagram. Please be specific to this case and not give general simplification tips.
#"""

def get_local_api_key():
"""Get the API key from key.txt file"""
f = open(GET_MODULE_PATH()+"/user/key.txt", "r")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be better to have the API key be settable in the settings of ZXLive.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was my plan! This is just a quick temporary solution so that I can push the changes. I guess we could also save API keys locally if we hash them first, so that way you won't have to re-enter it in every new instance of the program?

global API_KEY
API_KEY = f.read()
f.close()

def query_chatgpt(prompt,model="gpt-3.5-turbo"):
"""Send a prompt to Chat-GPT and return its response."""
client = openai.OpenAI(
api_key = API_KEY
)

chat_completion = client.chat.completions.create(
messages=[
{
"role":"user",
"content":prompt
}
],
model=model#"gpt-3.5-turbo"#"gpt-4o"
)
return chat_completion.choices[0].message.content

def prep_hint(g):
"""Generate the default/generic 'ask for hint' type prompt for DODO-GPT."""
strPrime = describe_graph(g)
strQuery = strPrime + "Please advise me as to what ONE simplification step I should take to help immediately simplify this ZX-diagram. Please be specific to this case and not give general simplification tips. Please also keep your answer simple and do not write any diagrams or data in return."
return strQuery

def describe_graph(g):
"""Prime DODO-GPT before making a query. Returns a string for describing to DODO-GPT the current ZX-diagram."""

VertexType = {0:'BOUNDARY',1:'Z',2:'X'}
EdgeType = {1:'SIMPLE',2:'HADAMARD'}

strPrime = "\nConsider a ZX-diagram defined by the following spiders:\n\nlabel, type, phase\n"
for v in g.vertices(): strPrime += str(v) + ', ' + str(VertexType[g.type(v)]) + ', ' + str(g.phase(v)) + '\n'

strPrime += "\nwith the following edges:\n\nsource,target,type\n"
for e in g.edges(): strPrime += str(e[0]) + ', ' + str(e[1]) + ', ' + str(EdgeType[g.edge_type(e)]) + '\n'
strPrime += '\n'

#Follows the format...
#
#"""
#Consider a ZX-diagram defined by the following spiders:
#
#label, type, phase
#0, Z, 0.25
#1, Z, 0.5
#2, X, 0.5
#
#with the following edges:
#
#source,target,type
#0, 1, SIMPLE
#1, 2, HADAMARD
#
#"""

return strPrime

def text_to_speech(text):
"""Generates an mp3 file reading the given text (via OpenAI Whisper)."""
client = openai.OpenAI(
api_key = API_KEY
)

response = client.audio.speech.create(
model="tts-1",
voice="nova",
input=text,
)

response.stream_to_file(GET_MODULE_PATH() + "/temp/Dodo_Dave_latest.mp3")
os.system(GET_MODULE_PATH() + "/temp/Dodo_Dave_latest.mp3") #TEMP/TODO - THIS SHOULD BE USING A PROPER IN-APP AUDIO PLAYER RATHER THAN OS

def record_audio(duration=5, sample_rate=44100):
recording = sd.rec(int(duration * sample_rate), samplerate=sample_rate, channels=2, dtype='int16')
sd.wait()
return recording

def save_as_mp3(audio_data, sample_rate=44100):
file_path = GET_MODULE_PATH() + "/temp/user_query_latest.mp3"
audio_segment = AudioSegment(
data=np.array(audio_data).tobytes(),
sample_width=2,
frame_rate=sample_rate,
channels=2
)
audio_segment.export(file_path, format='mp3')
return file_path

def transcribe_audio(file_path):
client = openai.OpenAI(api_key=API_KEY)
with open(file_path, "rb") as audio_file:
transcription = client.audio.transcriptions.create(
model="whisper-1",
file=audio_file,
language='en'
)
#print(f'Transcription: {transcription.text}') #TEMP
return transcription.text

def speech_to_text():
sample_rate = 44100 # Sample rate in Hz
duration = 5 # Duration of recording in seconds
audio_data = record_audio(duration, sample_rate)
file_path = save_as_mp3(audio_data, sample_rate)
txt = transcribe_audio(file_path)
return txt

def encode_image(image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')

def image_to_text(image_path):
"""Takes a ZX-diagram-like image and returns DODO-GPT's structured description of it."""

query = get_image_prompt()

# Getting the base64 string
base64_image = encode_image(image_path)

headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {API_KEY}"
}

payload = {
"model": "gpt-4o-mini",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": query
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
}
]
}
],
"max_tokens": 300
}

response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)

#return response.json()
return response.json()['choices'][0]['message']['content']

def response_to_zx(strResponse):
scale = 2.5

strResponse = strResponse
strResponse = strResponse[strResponse.index('index,type,phase'):]
strResponse = strResponse[strResponse.index('\n')+1:]
str_csv_verts = strResponse[:strResponse.index('```')-1]
strResponse = strResponse[strResponse.index('source,target,type'):]
strResponse = strResponse[strResponse.index('\n')+1:]
str_csv_edges = strResponse[:strResponse.index('```')-1]

g = zx.Graph()

for line in str_csv_verts.split('\n'):
idx,ty,ph,x,y = line.split(',')
g.add_vertex(qubit=float(y)*scale,row=float(x)*scale,ty=VertexType[ty],phase=ph)

for line in str_csv_edges.split('\n'):
source,target,ty = line.split(',')
g.add_edge((int(source),int(target)),int(ty))

return g

def action_dodo_hint(active_graph) -> None:
"""Queries DODO-GPT for a hint as to what simplification step should be taken next."""
#print("\n\nQUERY...\n\n", prep_hint(active_graph), "\n\nANSWER...\n\n") #TEMP
dodoResponse = query_chatgpt(prep_hint(active_graph))
#print(dodoResponse) #TEMP
text_to_speech(dodoResponse)

def action_dodo_query(active_graph) -> None:
"""Records the user's voice (plus the current ZX-diagram) and prompts DODO-GPT for a response."""
doIncludeGraph = True # Whether or not to pass information about the current ZX-diagram in with the DODO-GPT query
strPrime = describe_graph(active_graph)
userQuery = speech_to_text()
#print("\n\nQUERY...\n\n", strPrime+userQuery, "\n\nANSWER...\n\n") #TEMP
dodoResponse = query_chatgpt(strPrime+userQuery)
#print(dodoResponse) #TEMP
text_to_speech(dodoResponse)

def action_dodo_image_to_zx(path) -> None:
"""Queries DODO-GPT to generate a ZX-diagram from an image."""
strResponse = image_to_text(path)
#print(strResponse) #TEMP
new_graph = response_to_zx(strResponse)
return new_graph
26 changes: 24 additions & 2 deletions zxlive/edit_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,21 @@
from typing import Iterator

from PySide6.QtCore import Signal, QSettings
from PySide6.QtGui import QAction
from PySide6.QtGui import (QAction, QIcon)
from PySide6.QtWidgets import (QToolButton)
from pyzx import EdgeType, VertexType, sqasm
from pyzx.circuit.qasmparser import QASMParser
from pyzx.symbolic import Poly

from .base_panel import ToolbarSection
from .commands import UpdateGraph
from .common import GraphT
from .common import GraphT, get_data
from .dialogs import show_error_msg, create_circuit_dialog
from .editor_base_panel import EditorBasePanel
from .graphscene import EditGraphScene
from .graphview import GraphView
from .settings_dialog import input_circuit_formats
from .dodo import action_dodo_hint, action_dodo_query


class GraphEditPanel(EditorBasePanel):
Expand Down Expand Up @@ -60,6 +61,18 @@ def _toolbar_sections(self) -> Iterator[ToolbarSection]:
self.start_derivation.setText("Start Derivation")
self.start_derivation.clicked.connect(self._start_derivation)
yield ToolbarSection(self.start_derivation)

self.dodo_hint = QToolButton(self)
self.dodo_hint.setIcon(QIcon(get_data("icons/dodo.png")))
self.dodo_hint.setToolTip("Ask DODO-GPT for suggestions on how to rewrite your diagram")
self.dodo_hint.clicked.connect(self._dodo_hint)

self.dodo_query = QToolButton(self)
self.dodo_query.setIcon(QIcon(get_data("icons/mic.svg")))
self.dodo_query.setToolTip("Ask DODO-GPT anything via mic")
self.dodo_query.clicked.connect(self._dodo_query)
yield ToolbarSection(self.dodo_hint, self.dodo_query)


def _start_derivation(self) -> None:
if not self.graph_scene.g.is_well_formed():
Expand Down Expand Up @@ -107,3 +120,12 @@ def _input_circuit(self) -> None:
cmd = UpdateGraph(self.graph_view, new_g)
self.undo_stack.push(cmd)
self.graph_scene.select_vertices(new_verts)

def _dodo_hint(self) -> None:
action_dodo_hint(self.graph_scene.g)

def _dodo_query(self) -> None:
try:
action_dodo_query(self.graph_scene.g)
except:
print("DODO Query failed. Check if API key is valid and sounddevice was properly imported.") #TEMP
Binary file added zxlive/icons/dodo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions zxlive/icons/mic.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading