From 8bcc98284f1449a57c98ff8aff923ba3ce32bd1e Mon Sep 17 00:00:00 2001 From: Boldi Date: Wed, 15 Nov 2023 11:51:21 +0000 Subject: [PATCH] dev: Moving from tab based rewrite buttons to tree based side panel --- zxlive/animations.py | 64 ++++++- zxlive/custom_rule.py | 9 +- zxlive/proof_actions.py | 382 --------------------------------------- zxlive/proof_panel.py | 110 ++++------- zxlive/rewrite_action.py | 220 ++++++++++++++++++++++ zxlive/rewrite_data.py | 253 ++++++++++++++++++++++++++ 6 files changed, 580 insertions(+), 458 deletions(-) delete mode 100644 zxlive/proof_actions.py create mode 100644 zxlive/rewrite_action.py create mode 100644 zxlive/rewrite_data.py diff --git a/zxlive/animations.py b/zxlive/animations.py index 6c3b776d..e11da23d 100644 --- a/zxlive/animations.py +++ b/zxlive/animations.py @@ -1,16 +1,24 @@ +from __future__ import annotations + import itertools import random -from typing import Optional, Callable +from typing import Optional, Callable, TYPE_CHECKING from PySide6.QtCore import QEasingCurve, QPointF, QAbstractAnimation, \ QParallelAnimationGroup from PySide6.QtGui import QUndoStack, QUndoCommand from pyzx.utils import vertex_is_w -from .common import VT, GraphT, pos_to_view +from .custom_rule import CustomRule +from .rewrite_data import operations +from .common import VT, GraphT, pos_to_view, ANIMATION_DURATION from .graphscene import GraphScene from .vitem import VItem, VItemAnimation, VITEM_UNSELECTED_Z, VITEM_SELECTED_Z, get_w_partner_vitem +if TYPE_CHECKING: + from .proof_panel import ProofPanel + from .rewrite_action import RewriteAction + class AnimatedUndoStack(QUndoStack): """An undo stack that can play animations between actions.""" @@ -256,3 +264,55 @@ def unfuse(before: GraphT, after: GraphT, src: VT, scene: GraphScene) -> QAbstra return morph_graph(before, after, scene, to_start=lambda _: src, to_end=lambda _: None, duration=700, ease=QEasingCurve(QEasingCurve.Type.OutElastic)) + +def make_animation(self: RewriteAction, panel: ProofPanel, g, matches, rem_verts) -> tuple: + anim_before = None + anim_after = None + if self.name == operations['spider']['text'] or self.name == operations['fuse_w']['text']: + anim_before = QParallelAnimationGroup() + for v1, v2 in matches: + if v1 in rem_verts: + v1, v2 = v2, v1 + anim_before.addAnimation(fuse(panel.graph_scene.vertex_map[v2], panel.graph_scene.vertex_map[v1])) + elif self.name == operations['to_z']['text']: + print('To do: animate ' + self.name) + elif self.name == operations['to_x']['text']: + print('To do: animate ' + self.name) + elif self.name == operations['rem_id']['text']: + anim_before = QParallelAnimationGroup() + for m in matches: + anim_before.addAnimation(remove_id(panel.graph_scene.vertex_map[m[0]])) + elif self.name == operations['copy']['text']: + anim_before = QParallelAnimationGroup() + for m in matches: + anim_before.addAnimation(fuse(panel.graph_scene.vertex_map[m[0]], + panel.graph_scene.vertex_map[m[1]])) + anim_after = QParallelAnimationGroup() + for m in matches: + anim_after.addAnimation(strong_comp(panel.graph, g, m[1], panel.graph_scene)) + elif self.name == operations['pauli']['text']: + print('To do: animate ' + self.name) + elif self.name == operations['bialgebra']['text']: + anim_before = QParallelAnimationGroup() + for v1, v2 in matches: + anim_before.addAnimation(fuse(panel.graph_scene.vertex_map[v1], + panel.graph_scene.vertex_map[v2], meet_halfway=True)) + anim_after = QParallelAnimationGroup() + for v1, v2 in matches: + v2_row, v2_qubit = panel.graph.row(v2), panel.graph.qubit(v2) + panel.graph.set_row(v2, (panel.graph.row(v1) + v2_row) / 2) + panel.graph.set_qubit(v2, (panel.graph.qubit(v1) + v2_qubit) / 2) + anim_after.addAnimation(strong_comp(panel.graph, g, v2, panel.graph_scene)) + panel.graph.set_row(v2, v2_row) + panel.graph.set_qubit(v2, v2_qubit) + elif isinstance(self.rule, CustomRule) and self.rule.last_rewrite_center is not None: + center = self.rule.last_rewrite_center + duration = ANIMATION_DURATION / 2 + anim_before = morph_graph_to_center(panel.graph, lambda v: v not in g.graph, + panel.graph_scene, center, duration, + QEasingCurve(QEasingCurve.Type.InQuad)) + anim_after = morph_graph_from_center(g, lambda v: v not in panel.graph.graph, + panel.graph_scene, center, duration, + QEasingCurve(QEasingCurve.Type.OutQuad)) + + return anim_before, anim_after diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 349a9463..daff45cd 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -16,7 +16,7 @@ from .common import ET, VT, GraphT if TYPE_CHECKING: - from .proof_actions import ProofAction + from .rewrite_data import RewriteData class CustomRule: def __init__(self, lhs_graph: GraphT, rhs_graph: GraphT, name: str, description: str) -> None: @@ -109,9 +109,10 @@ def from_json(cls, json_str: str) -> "CustomRule": assert (isinstance(lhs_graph, GraphT) and isinstance(rhs_graph, GraphT)) # type: ignore return cls(lhs_graph, rhs_graph, d['name'], d['description']) - def to_proof_action(self) -> "ProofAction": - from .proof_actions import MATCHES_VERTICES, ProofAction - return ProofAction(self.name, self.matcher, self, MATCHES_VERTICES, self.description) + def to_rewrite_data(self) -> "RewriteData": + from .rewrite_data import MATCHES_VERTICES + return {"text": self.name, "matcher": self.matcher, "rule": self, "type": MATCHES_VERTICES, + "tooltip": self.description, 'copy_first': False, 'returns_new_graph': False} def get_linear(v): diff --git a/zxlive/proof_actions.py b/zxlive/proof_actions.py deleted file mode 100644 index 46df31bc..00000000 --- a/zxlive/proof_actions.py +++ /dev/null @@ -1,382 +0,0 @@ -import copy -from dataclasses import dataclass, field, replace -from typing import Callable, Literal, Optional, TYPE_CHECKING - -import pyzx -from pyzx import simplify, extract_circuit - -from PySide6.QtWidgets import QPushButton, QButtonGroup -from PySide6.QtCore import QParallelAnimationGroup, QEasingCurve - -from . import animations as anims -from .commands import AddRewriteStep -from .common import ANIMATION_DURATION, ET, GraphT, VT -from .custom_rule import CustomRule -from .dialogs import show_error_msg - -if TYPE_CHECKING: - from .proof_panel import ProofPanel - -operations = copy.deepcopy(pyzx.editor.operations) - -MatchType = Literal[1, 2] - -# Copied from pyzx.editor_actions -MATCHES_VERTICES: MatchType = 1 -MATCHES_EDGES: MatchType = 2 - - -@dataclass -class ProofAction(object): - name: str - matcher: Callable[[GraphT, Callable], list] - rule: Callable[[GraphT, list], pyzx.rules.RewriteOutputType[ET,VT]] - match_type: MatchType - tooltip: str - copy_first: bool = field(default=False) # Whether the graph should be copied before trying to test whether it matches. Needed if the matcher changes the graph. - returns_new_graph: bool = field(default=False) # Whether the rule returns a new graph instead of returning the rewrite changes. - button: Optional[QPushButton] = field(default=None, init=False) - - @classmethod - def from_dict(cls, d: dict) -> "ProofAction": - if 'copy_first' not in d: - d['copy_first'] = False - if 'returns_new_graph' not in d: - d['returns_new_graph'] = False - return cls(d['text'], d['matcher'], d['rule'], d['type'], d['tooltip'], d['copy_first'], d['returns_new_graph']) - - def do_rewrite(self, panel: "ProofPanel") -> None: - verts, edges = panel.parse_selection() - g = copy.deepcopy(panel.graph_scene.g) - - if self.match_type == MATCHES_VERTICES: - matches = self.matcher(g, lambda v: v in verts) - else: - matches = self.matcher(g, lambda e: e in edges) - - try: - if self.returns_new_graph: - g = self.rule(g, matches) - else: - etab, rem_verts, rem_edges, check_isolated_vertices = self.rule(g, matches) - g.remove_edges(rem_edges) - g.remove_vertices(rem_verts) - g.add_edge_table(etab) - except Exception as e: - show_error_msg('Error while applying rewrite rule', str(e)) - return - - cmd = AddRewriteStep(panel.graph_view, g, panel.step_view, self.name) - anim_before = None - anim_after = None - if self.name == operations['spider']['text'] or self.name == operations['fuse_w']['text']: - anim_before = QParallelAnimationGroup() - for v1, v2 in matches: - if v1 in rem_verts: - v1, v2 = v2, v1 - anim_before.addAnimation(anims.fuse(panel.graph_scene.vertex_map[v2], panel.graph_scene.vertex_map[v1])) - elif self.name == operations['to_z']['text']: - print('To do: animate ' + self.name) - elif self.name == operations['to_x']['text']: - print('To do: animate ' + self.name) - elif self.name == operations['rem_id']['text']: - anim_before = QParallelAnimationGroup() - for m in matches: - anim_before.addAnimation(anims.remove_id(panel.graph_scene.vertex_map[m[0]])) - elif self.name == operations['copy']['text']: - anim_before = QParallelAnimationGroup() - for m in matches: - anim_before.addAnimation(anims.fuse(panel.graph_scene.vertex_map[m[0]], - panel.graph_scene.vertex_map[m[1]])) - anim_after = QParallelAnimationGroup() - for m in matches: - anim_after.addAnimation(anims.strong_comp(panel.graph, g, m[1], panel.graph_scene)) - elif self.name == operations['pauli']['text']: - print('To do: animate ' + self.name) - elif self.name == operations['bialgebra']['text']: - anim_before = QParallelAnimationGroup() - for v1, v2 in matches: - anim_before.addAnimation(anims.fuse(panel.graph_scene.vertex_map[v1], - panel.graph_scene.vertex_map[v2], meet_halfway=True)) - anim_after = QParallelAnimationGroup() - for v1, v2 in matches: - v2_row, v2_qubit = panel.graph.row(v2), panel.graph.qubit(v2) - panel.graph.set_row(v2, (panel.graph.row(v1) + v2_row) / 2) - panel.graph.set_qubit(v2, (panel.graph.qubit(v1) + v2_qubit) / 2) - anim_after.addAnimation(anims.strong_comp(panel.graph, g, v2, panel.graph_scene)) - panel.graph.set_row(v2, v2_row) - panel.graph.set_qubit(v2, v2_qubit) - elif isinstance(self.rule, CustomRule) and self.rule.last_rewrite_center is not None: - center = self.rule.last_rewrite_center - duration = ANIMATION_DURATION / 2 - anim_before = anims.morph_graph_to_center(panel.graph, lambda v: v not in g.graph, - panel.graph_scene, center, duration, - QEasingCurve(QEasingCurve.Type.InQuad)) - anim_after = anims.morph_graph_from_center(g, lambda v: v not in panel.graph.graph, - panel.graph_scene, center, duration, - QEasingCurve(QEasingCurve.Type.OutQuad)) - - panel.undo_stack.push(cmd, anim_before=anim_before, anim_after=anim_after) - - def update_active(self, g: GraphT, verts: list[VT], edges: list[ET]) -> None: - if self.copy_first: - g = copy.deepcopy(g) - if self.match_type == MATCHES_VERTICES: - matches = self.matcher(g, lambda v: v in verts) - else: - matches = self.matcher(g, lambda e: e in edges) - - if self.button is None: return - if matches: - self.button.setEnabled(True) - else: - self.button.setEnabled(False) - - -class ProofActionGroup(object): - def __init__(self, name: str, *actions: ProofAction) -> None: - self.name = name - self.actions = actions - self.btn_group: Optional[QButtonGroup] = None - self.parent_panel = None - - def copy(self) -> "ProofActionGroup": - copied_actions = [] - for action in self.actions: - action_copy = replace(action) - action_copy.button = None - copied_actions.append(action_copy) - return ProofActionGroup(self.name, *copied_actions) - - def init_buttons(self, parent: "ProofPanel") -> None: - self.btn_group = QButtonGroup(parent) - self.btn_group.setExclusive(False) - def create_rewrite(action: ProofAction, parent: "ProofPanel") -> Callable[[], None]: # Needed to prevent weird bug with closures in signals - def rewriter() -> None: - action.do_rewrite(parent) - return rewriter - for action in self.actions: - if action.button is not None: continue - btn = QPushButton(action.name, parent) - btn.setMaximumWidth(150) - btn.setStatusTip(action.tooltip) - btn.setEnabled(False) - btn.clicked.connect(create_rewrite(action, parent)) - self.btn_group.addButton(btn) - action.button = btn - - def update_active(self, g: GraphT, verts: list[VT], edges: list[ET]) -> None: - for action in self.actions: - action.update_active(g, verts, edges) - -# We want additional actions that are not part of the original PyZX editor -# So we add them to operations - -operations.update({ - "pivot_boundary": {"text": "boundary pivot", - "tooltip": "Performs a pivot between a Pauli spider and a spider on the boundary.", - "matcher": pyzx.rules.match_pivot_boundary, - "rule": pyzx.rules.pivot, - "type": MATCHES_EDGES, - "copy_first": True}, - "pivot_gadget": {"text": "gadget pivot", - "tooltip": "Performs a pivot between a Pauli spider and a spider with an arbitrary phase, creating a phase gadget.", - "matcher": pyzx.rules.match_pivot_gadget, - "rule": pyzx.rules.pivot, - "type": MATCHES_EDGES, - "copy_first": True}, - "phase_gadget_fuse": {"text": "Fuse phase gadgets", - "tooltip": "Fuses two phase gadgets with the same connectivity.", - "matcher": pyzx.rules.match_phase_gadgets, - "rule": pyzx.rules.merge_phase_gadgets, - "type": MATCHES_VERTICES, - "copy_first": True}, - "supplementarity": {"text": "Supplementarity", - "tooltip": "Looks for a pair of internal spiders with the same connectivity and supplementary angles and removes them.", - "matcher": pyzx.rules.match_supplementarity, - "rule": pyzx.rules.apply_supplementarity, - "type": MATCHES_VERTICES, - "copy_first": False}, - } -) - -always_true = lambda graph, matches: matches - -def apply_simplification(simplification: Callable[[GraphT], GraphT]) -> Callable[[GraphT, list], pyzx.rules.RewriteOutputType[ET,VT]]: - def rule(g: GraphT, matches: list) -> pyzx.rules.RewriteOutputType[ET,VT]: - simplification(g) - return ({}, [], [], True) - return rule - -def _extract_circuit(graph: GraphT, matches: list) -> GraphT: - graph.auto_detect_io() - simplify.full_reduce(graph) - return extract_circuit(graph).to_graph() - -simplifications: dict = { - 'bialg_simp': { - "text": "bialgebra", - "tooltip": "bialg_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.bialg_simp), - "type": MATCHES_VERTICES, - }, - 'spider_simp': { - "text": "spider fusion", - "tooltip": "spider_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.spider_simp), - "type": MATCHES_VERTICES, - }, - 'id_simp': { - "text": "id", - "tooltip": "id_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.id_simp), - "type": MATCHES_VERTICES, - }, - 'phase_free_simp': { - "text": "phase free", - "tooltip": "phase_free_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.phase_free_simp), - "type": MATCHES_VERTICES, - }, - 'pivot_simp': { - "text": "pivot", - "tooltip": "pivot_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.pivot_simp), - "type": MATCHES_VERTICES, - }, - 'pivot_gadget_simp': { - "text": "pivot gadget", - "tooltip": "pivot_gadget_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.pivot_gadget_simp), - "type": MATCHES_VERTICES, - }, - 'pivot_boundary_simp': { - "text": "pivot boundary", - "tooltip": "pivot_boundary_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.pivot_boundary_simp), - "type": MATCHES_VERTICES, - }, - 'gadget_simp': { - "text": "gadget", - "tooltip": "gadget_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.gadget_simp), - "type": MATCHES_VERTICES, - }, - 'lcomp_simp': { - "text": "local complementation", - "tooltip": "lcomp_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.lcomp_simp), - "type": MATCHES_VERTICES, - }, - 'clifford_simp': { - "text": "clifford simplification", - "tooltip": "clifford_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.clifford_simp), - "type": MATCHES_VERTICES, - }, - 'tcount': { - "text": "tcount", - "tooltip": "tcount", - "matcher": always_true, - "rule": apply_simplification(simplify.tcount), - "type": MATCHES_VERTICES, - }, - 'to_gh': { - "text": "to green-hadamard form", - "tooltip": "to_gh", - "matcher": always_true, - "rule": apply_simplification(simplify.to_gh), - "type": MATCHES_VERTICES, - }, - 'to_rg': { - "text": "to red-green form", - "tooltip": "to_rg", - "matcher": always_true, - "rule": apply_simplification(simplify.to_rg), - "type": MATCHES_VERTICES, - }, - 'full_reduce': { - "text": "full reduce", - "tooltip": "full_reduce", - "matcher": always_true, - "rule": apply_simplification(simplify.full_reduce), - "type": MATCHES_VERTICES, - }, - 'teleport_reduce': { - "text": "teleport reduce", - "tooltip": "teleport_reduce", - "matcher": always_true, - "rule": apply_simplification(simplify.teleport_reduce), - "type": MATCHES_VERTICES, - }, - 'reduce_scalar': { - "text": "reduce scalar", - "tooltip": "reduce_scalar", - "matcher": always_true, - "rule": apply_simplification(simplify.reduce_scalar), - "type": MATCHES_VERTICES, - }, - 'supplementarity_simp': { - "text": "supplementarity", - "tooltip": "supplementarity_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.supplementarity_simp), - "type": MATCHES_VERTICES, - }, - 'to_clifford_normal_form_graph': { - "text": "to clifford normal form", - "tooltip": "to_clifford_normal_form_graph", - "matcher": always_true, - "rule": apply_simplification(simplify.to_clifford_normal_form_graph), - "type": MATCHES_VERTICES, - }, - 'extract_circuit': { - "text": "circuit extraction", - "tooltip": "extract_circuit", - "matcher": always_true, - "rule": _extract_circuit, - "type": MATCHES_VERTICES, - "returns_new_graph": True, - }, -} - - -spider_fuse = ProofAction.from_dict(operations['spider']) -to_z = ProofAction.from_dict(operations['to_z']) -to_x = ProofAction.from_dict(operations['to_x']) -rem_id = ProofAction.from_dict(operations['rem_id']) -copy_action = ProofAction.from_dict(operations['copy']) -pauli = ProofAction.from_dict(operations['pauli']) -bialgebra = ProofAction.from_dict(operations['bialgebra']) -euler_rule = ProofAction.from_dict(operations['euler']) -rules_basic = ProofActionGroup("Basic rules", spider_fuse, to_z, to_x, rem_id, copy_action, pauli, bialgebra, euler_rule).copy() - -lcomp = ProofAction.from_dict(operations['lcomp']) -pivot = ProofAction.from_dict(operations['pivot']) -pivot_boundary = ProofAction.from_dict(operations['pivot_boundary']) -pivot_gadget = ProofAction.from_dict(operations['pivot_gadget']) -supplementarity = ProofAction.from_dict(operations['supplementarity']) -rules_graph_theoretic = ProofActionGroup("Graph-like rules", lcomp, pivot, pivot_boundary, pivot_gadget, supplementarity).copy() - -w_fuse = ProofAction.from_dict(operations['fuse_w']) -z_to_z_box = ProofAction.from_dict(operations['z_to_z_box']) -rules_zxw = ProofActionGroup("ZXW rules",spider_fuse, w_fuse, z_to_z_box).copy() - -hbox_to_edge = ProofAction.from_dict(operations['had2edge']) -fuse_hbox = ProofAction.from_dict(operations['fuse_hbox']) -mult_hbox = ProofAction.from_dict(operations['mult_hbox']) -rules_zh = ProofActionGroup("ZH rules", hbox_to_edge, fuse_hbox, mult_hbox).copy() - -simplification_actions = ProofActionGroup("Simplification routines", *[ProofAction.from_dict(s) for s in simplifications.values()]).copy() - -action_groups = [rules_basic, rules_graph_theoretic, rules_zxw, rules_zh, simplification_actions] diff --git a/zxlive/proof_panel.py b/zxlive/proof_panel.py index 1697257d..d502e44c 100644 --- a/zxlive/proof_panel.py +++ b/zxlive/proof_panel.py @@ -1,30 +1,26 @@ from __future__ import annotations import copy -import os -from fractions import Fraction from typing import Iterator, Union, cast import pyzx from PySide6.QtCore import (QItemSelection, QModelIndex, QPersistentModelIndex, QPointF, QRect, QSize, Qt) from PySide6.QtGui import (QAction, QColor, QFont, QFontMetrics, QIcon, - QPainter, QPen, QVector2D) -from PySide6.QtWidgets import (QAbstractItemView, QHBoxLayout, QListView, + QPainter, QPen, QVector2D, QFontInfo) +from PySide6.QtWidgets import (QAbstractItemView, QListView, QStyle, QStyledItemDelegate, - QStyleOptionViewItem, QToolButton, QWidget, - QVBoxLayout, QTabWidget, QInputDialog) + QStyleOptionViewItem, QToolButton, + QInputDialog, QTreeView) from pyzx import VertexType, basicrules from pyzx.graph.jsonparser import string_to_phase from pyzx.utils import get_z_box_label, set_z_box_label, get_w_partner, EdgeType, FractionLike from . import animations as anims -from . import proof_actions from .base_panel import BasePanel, ToolbarSection from .commands import AddRewriteStep, GoToRewriteStep, MoveNodeInStep -from .common import (get_custom_rules_path, ET, SCALE, VT, GraphT, get_data, +from .common import (ET, VT, GraphT, get_data, pos_from_view, pos_to_view, colors) -from .custom_rule import CustomRule from .dialogs import show_error_msg from .eitem import EItem from .graphscene import GraphScene @@ -32,6 +28,8 @@ from .proof import ProofModel from .vitem import DragState, VItem, W_INPUT_OFFSET, SCALE from .editor_base_panel import string_to_complex +from .rewrite_data import action_groups, refresh_custom_rules +from .rewrite_action import RewriteActionTreeModel class ProofPanel(BasePanel): @@ -41,8 +39,6 @@ def __init__(self, graph: GraphT, *actions: QAction) -> None: super().__init__(*actions) self.graph_scene = GraphScene() self.graph_scene.vertices_moved.connect(self._vert_moved) - # TODO: Right now this calls for every single vertex selected, even if we select many at the same time - self.graph_scene.selectionChanged.connect(self.update_on_selection) self.graph_scene.vertex_double_clicked.connect(self._vert_double_clicked) @@ -50,10 +46,9 @@ def __init__(self, graph: GraphT, *actions: QAction) -> None: self.splitter.addWidget(self.graph_view) self.graph_view.set_graph(graph) - self.actions_bar = QTabWidget(self) - self.layout().insertWidget(1, self.actions_bar) # type: ignore - self.init_action_groups() - self.actions_bar.currentChanged.connect(self.update_on_selection) + self.rewrites_panel = QTreeView(self) + self.splitter.insertWidget(0, self.rewrites_panel) + self.init_rewrites_bar() self.graph_view.wand_trace_finished.connect(self._wand_trace_finished) self.graph_scene.vertex_dragged.connect(self._vertex_dragged) @@ -104,35 +99,29 @@ def _toolbar_sections(self) -> Iterator[ToolbarSection]: self.refresh_rules = QToolButton(self) self.refresh_rules.setText("Refresh rules") - self.refresh_rules.clicked.connect(self._refresh_rules) + self.refresh_rules.clicked.connect(self._refresh_rewrites_model) yield ToolbarSection(*self.identity_choice, exclusive=True) yield ToolbarSection(*self.actions()) yield ToolbarSection(self.refresh_rules) - def init_action_groups(self) -> None: - self.action_groups = [group.copy() for group in proof_actions.action_groups] - custom_rules = [] - for root, dirs, files in os.walk(get_custom_rules_path()): - for file in files: - if file.endswith(".zxr"): - zxr_file = os.path.join(root, file) - with open(zxr_file, "r") as f: - rule = CustomRule.from_json(f.read()).to_proof_action() - custom_rules.append(rule) - self.action_groups.append(proof_actions.ProofActionGroup("Custom rules", *custom_rules).copy()) - for group in self.action_groups: - hlayout = QHBoxLayout() - group.init_buttons(self) - for action in group.actions: - assert action.button is not None - hlayout.addWidget(action.button) - hlayout.addStretch() - - widget = QWidget() - widget.setLayout(hlayout) - setattr(widget, "action_group", group) - self.actions_bar.addTab(widget, group.name) + def init_rewrites_bar(self) -> None: + self.rewrites_panel.setUniformRowHeights(True) + self.rewrites_panel.setSelectionMode(QAbstractItemView.SelectionMode.NoSelection) + fi = QFontInfo(self.font()) + + self.rewrites_panel.setStyleSheet( + f''' + QTreeView::Item:hover {{ + background-color: #e2f4ff; + }} + QTreeView::Item{{ + height:{fi.pixelSize() * 2}px; + }} + ''') + + # Set the models + self._refresh_rewrites_model() def parse_selection(self) -> tuple[list[VT], list[ET]]: selection = list(self.graph_scene.selected_vertices) @@ -145,12 +134,6 @@ def parse_selection(self) -> tuple[list[VT], list[ET]]: return selection, edges - def update_on_selection(self) -> None: - selection, edges = self.parse_selection() - g = self.graph_scene.g - action_group = getattr(self.actions_bar.currentWidget(), "action_group") - action_group.update_active(g, selection, edges) - def _vert_moved(self, vs: list[tuple[VT, float, float]]) -> None: cmd = MoveNodeInStep(self.graph_view, vs, self.step_view) self.undo_stack.push(cmd) @@ -232,7 +215,7 @@ def cross(a: QPointF, b: QPointF) -> float: if not trace.shift and basicrules.check_remove_id(self.graph, vertex): self._remove_id(vertex) return True - + if trace.shift and self.graph.type(vertex) != VertexType.W_OUTPUT: phase_is_complex = (self.graph.type(vertex) == VertexType.Z_BOX) if phase_is_complex: @@ -254,7 +237,7 @@ def cross(a: QPointF, b: QPointF) -> float: phase = get_z_box_label(self.graph, vertex) else: phase = self.graph.phase(vertex) - + start = trace.hit[item][0] end = trace.hit[item][-1] if start.y() > end.y(): @@ -272,7 +255,7 @@ def cross(a: QPointF, b: QPointF) -> float: else: right.append(neighbor) mouse_dir = ((start + end) * (1/2)) - pos - + if self.graph.type(vertex) == VertexType.W_OUTPUT: self._unfuse_w(vertex, left, mouse_dir) else: @@ -291,7 +274,7 @@ def _unfuse_w(self, v: VT, left_neighbours: list[VT], mouse_dir: QPointF) -> Non vi = get_w_partner(self.graph, v) par_dir = QVector2D( - self.graph.row(v) - self.graph.row(vi), + self.graph.row(v) - self.graph.row(vi), self.graph.qubit(v) - self.graph.qubit(vi) ).normalized() @@ -408,27 +391,14 @@ def _proof_step_selected(self, selected: QItemSelection, deselected: QItemSelect cmd = GoToRewriteStep(self.graph_view, self.step_view, deselected.first().topLeft().row(), selected.first().topLeft().row()) self.undo_stack.push(cmd) - def _refresh_rules(self) -> None: - self.actions_bar.removeTab(self.actions_bar.count() - 1) - custom_rules = [] - for root, dirs, files in os.walk(get_custom_rules_path()): - for file in files: - if file.endswith(".zxr"): - zxr_file = os.path.join(root, file) - with open(zxr_file, "r") as f: - rule = CustomRule.from_json(f.read()).to_proof_action() - custom_rules.append(rule) - group = proof_actions.ProofActionGroup("Custom rules", *custom_rules).copy() - hlayout = QHBoxLayout() - group.init_buttons(self) - for action in group.actions: - assert action.button is not None - hlayout.addWidget(action.button) - hlayout.addStretch() - widget = QWidget() - widget.setLayout(hlayout) - setattr(widget, "action_group", group) - self.actions_bar.addTab(widget, group.name) + def _refresh_rewrites_model(self) -> None: + refresh_custom_rules() + model = RewriteActionTreeModel.from_dict(action_groups, self) + self.rewrites_panel.setModel(model) + self.rewrites_panel.clicked.connect(model.do_rewrite) + # TODO: Right now this calls for every single vertex selected, even if we select many at the same time + self.graph_scene.selectionChanged.connect(model.update_on_selection) + self.rewrites_panel.expandAll() class ProofStepItemDelegate(QStyledItemDelegate): diff --git a/zxlive/rewrite_action.py b/zxlive/rewrite_action.py new file mode 100644 index 00000000..1d2b01b4 --- /dev/null +++ b/zxlive/rewrite_action.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +import copy +from dataclasses import dataclass, field +from typing import Callable, TYPE_CHECKING + +import pyzx +from PySide6.QtCore import Qt, QAbstractItemModel, QModelIndex + +from .animations import make_animation +from .commands import AddRewriteStep +from .common import ET, GraphT, VT +from .dialogs import show_error_msg +from .rewrite_data import is_rewrite_data, RewriteData, MatchType, MATCHES_VERTICES + +if TYPE_CHECKING: + from .proof_panel import ProofPanel + +operations = copy.deepcopy(pyzx.editor.operations) + + +@dataclass +class RewriteAction: + name: str + matcher: Callable[[GraphT, Callable], list] + rule: Callable[[GraphT, list], pyzx.rules.RewriteOutputType[ET, VT]] + match_type: MatchType + tooltip: str + # Whether the graph should be copied before trying to test whether it matches. + # Needed if the matcher changes the graph. + copy_first: bool = field(default=False) + # Whether the rule returns a new graph instead of returning the rewrite changes. + returns_new_graph: bool = field(default=False) + enabled: bool = field(default=False) + + @classmethod + def from_rewrite_data(cls, d: RewriteData) -> RewriteAction: + return cls( + name=d['text'], + matcher=d['matcher'], + rule=d['rule'], + match_type=d['type'], + tooltip=d['tooltip'], + copy_first=d.get('copy_first', False), + returns_new_graph=d.get('returns_new_graph', False), + ) + + def do_rewrite(self, panel: ProofPanel) -> None: + if not self.enabled: + return + + g = copy.deepcopy(panel.graph_scene.g) + verts, edges = panel.parse_selection() + + matches = self.matcher(g, lambda v: v in verts) \ + if self.match_type == MATCHES_VERTICES \ + else self.matcher(g, lambda e: e in edges) + + try: + g, rem_verts = self.apply_rewrite(g, matches) + except Exception as e: + show_error_msg('Error while applying rewrite rule', str(e)) + return + + cmd = AddRewriteStep(panel.graph_view, g, panel.step_view, self.name) + anim_before, anim_after = make_animation(self, panel, g, matches, rem_verts) + panel.undo_stack.push(cmd, anim_before=anim_before, anim_after=anim_after) + + def apply_rewrite(self, g: GraphT, matches: list): + if self.returns_new_graph: + return self.rule(g, matches), None + + etab, rem_verts, rem_edges, check_isolated_vertices = self.rule(g, matches) + g.remove_edges(rem_edges) + g.remove_vertices(rem_verts) + g.add_edge_table(etab) + return g, rem_verts + + def update_active(self, g: GraphT, verts: list[VT], edges: list[ET]) -> None: + if self.copy_first: + g = copy.deepcopy(g) + self.enabled = bool( + self.matcher(g, lambda v: v in verts) + if self.match_type == MATCHES_VERTICES + else self.matcher(g, lambda e: e in edges) + ) + + +@dataclass +class RewriteActionTree: + id: str + rewrite: RewriteAction | None + child_items: list[RewriteActionTree] + parent: RewriteActionTree | None + + @property + def is_rewrite(self) -> bool: + return self.rewrite is not None + + @property + def rewrite_action(self) -> RewriteAction: + assert self.rewrite is not None + return self.rewrite + + def append_child(self, child: RewriteActionTree) -> None: + self.child_items.append(child) + + def child(self, row: int) -> RewriteActionTree: + assert -len(self.child_items) <= row < len(self.child_items) + return self.child_items[row] + + def child_count(self) -> int: + return len(self.child_items) + + def row(self) -> int | None: + return self.parent.child_items.index(self) if self.parent else None + + def header(self) -> str: + return self.id if self.rewrite is None else self.rewrite.name + + def tooltip(self) -> str: + return "" if self.rewrite is None else self.rewrite.tooltip + + def enabled(self) -> bool: + return self.rewrite is None or self.rewrite.enabled + + @classmethod + def from_dict(cls, d: dict, header: str = "", parent: RewriteActionTree | None = None) -> RewriteActionTree: + if is_rewrite_data(d): + return RewriteActionTree( + header, RewriteAction.from_rewrite_data(d), [], parent + ) + ret = RewriteActionTree(header, None, [], parent) + for group, actions in d.items(): + ret.append_child(cls.from_dict(actions, group, ret)) + return ret + + def update_on_selection(self, g, selection, edges): + for child in self.child_items: + child.update_on_selection(g, selection, edges) + if self.rewrite is not None: + self.rewrite.update_active(g, selection, edges) + + +class RewriteActionTreeModel(QAbstractItemModel): + root_item: RewriteActionTree + + def __init__(self, data: RewriteActionTree, proof_panel: ProofPanel): + super().__init__(proof_panel) + self.proof_panel = proof_panel + self.root_item = data + + @classmethod + def from_dict(cls, d: dict, proof_panel: ProofPanel): + return RewriteActionTreeModel( + RewriteActionTree.from_dict(d), + proof_panel + ) + + def index(self, row: int, column: int, parent: QModelIndex = None) -> QModelIndex: + if not self.hasIndex(row, column, parent): + return QModelIndex() + + parentItem = parent.internalPointer() if parent.isValid() else self.root_item + + if childItem := parentItem.child(row): + return self.createIndex(row, column, childItem) + return QModelIndex() + + def parent(self, index: QModelIndex = None) -> QModelIndex: + if not index.isValid(): + return QModelIndex() + + parentItem = index.internalPointer().parent + + if parentItem == self.root_item: + return QModelIndex() + + return self.createIndex(parentItem.row(), 0, parentItem) + + def rowCount(self, parent: QModelIndex = None) -> int: + if parent.column() > 0: + return 0 + parentItem = parent.internalPointer() if parent.isValid() else self.root_item + return parentItem.child_count() + + def columnCount(self, parent: QModelIndex = None) -> int: + return 1 + + def flags(self, index: QModelIndex) -> Qt.ItemFlag: + if index.isValid(): + return Qt.ItemFlag.ItemIsEnabled if index.internalPointer().enabled() else Qt.ItemFlag.NoItemFlags + return Qt.ItemFlag.ItemIsEnabled + + def data(self, index: QModelIndex, role: Qt.ItemDataRole = Qt.ItemDataRole.DisplayRole) -> str: + if index.isValid() and role == Qt.ItemDataRole.DisplayRole: + return index.internalPointer().header() + if index.isValid() and role == Qt.ItemDataRole.ToolTipRole: + return index.internalPointer().tooltip() + elif not index.isValid(): + return self.root_item.header() + + def headerData(self, section: int, orientation: Qt.Orientation, + role: Qt.ItemDataRole = Qt.ItemDataRole.DisplayRole) -> str: + if orientation == Qt.Orientation.Horizontal and role == Qt.ItemDataRole.DisplayRole: + return self.root_item.header() + return "" + + def do_rewrite(self, index: QModelIndex) -> None: + if not index.isValid(): + return + node = index.internalPointer() + if node.is_rewrite: + node.rewrite_action.do_rewrite(self.proof_panel) + + def update_on_selection(self) -> None: + selection, edges = self.proof_panel.parse_selection() + g = self.proof_panel.graph_scene.g + + self.root_item.update_on_selection(g, selection, edges) diff --git a/zxlive/rewrite_data.py b/zxlive/rewrite_data.py new file mode 100644 index 00000000..555c9a60 --- /dev/null +++ b/zxlive/rewrite_data.py @@ -0,0 +1,253 @@ +from __future__ import annotations + +import copy +import os +from typing import Callable, Literal, TypedDict + +import pyzx +from pyzx import simplify, extract_circuit + +from .common import ET, GraphT, VT, get_custom_rules_path +from .custom_rule import CustomRule + +operations = copy.deepcopy(pyzx.editor.operations) + +MatchType = Literal[1, 2] + +# Copied from pyzx.editor_actions +MATCHES_VERTICES: MatchType = 1 +MATCHES_EDGES: MatchType = 2 + + +class RewriteData(TypedDict): + text: str + matcher: Callable[[GraphT, Callable], list] + rule: Callable[[GraphT, list], pyzx.rules.RewriteOutputType[ET, VT]] + type: MatchType + tooltip: str + copy_first: bool | None + returns_new_graph: bool | None + + +def is_rewrite_data(d: dict) -> bool: + proof_action_keys = {"text", "tooltip", "matcher", "rule", "type"} + return proof_action_keys.issubset(set(d.keys())) + + +def read_custom_rules() -> list[RewriteData]: + custom_rules = [] + for root, dirs, files in os.walk(get_custom_rules_path()): + for file in files: + if file.endswith(".zxr"): + zxr_file = os.path.join(root, file) + with open(zxr_file, "r") as f: + rule = CustomRule.from_json(f.read()).to_rewrite_data() + custom_rules.append(rule) + return custom_rules + + +# We want additional actions that are not part of the original PyZX editor +# So we add them to operations + +rewrites_graph_theoretic: dict[str, RewriteData] = { + "pivot_boundary": {"text": "boundary pivot", + "tooltip": "Performs a pivot between a Pauli spider and a spider on the boundary.", + "matcher": pyzx.rules.match_pivot_boundary, + "rule": pyzx.rules.pivot, + "type": MATCHES_EDGES, + "copy_first": True}, + "pivot_gadget": {"text": "gadget pivot", + "tooltip": "Performs a pivot between a Pauli spider and a spider with an arbitrary phase, creating a phase gadget.", + "matcher": pyzx.rules.match_pivot_gadget, + "rule": pyzx.rules.pivot, + "type": MATCHES_EDGES, + "copy_first": True}, + "phase_gadget_fuse": {"text": "Fuse phase gadgets", + "tooltip": "Fuses two phase gadgets with the same connectivity.", + "matcher": pyzx.rules.match_phase_gadgets, + "rule": pyzx.rules.merge_phase_gadgets, + "type": MATCHES_VERTICES, + "copy_first": True}, + "supplementarity": {"text": "Supplementarity", + "tooltip": "Looks for a pair of internal spiders with the same connectivity and supplementary angles and removes them.", + "matcher": pyzx.rules.match_supplementarity, + "rule": pyzx.rules.apply_supplementarity, + "type": MATCHES_VERTICES, + "copy_first": False}, +} + +const_true = lambda graph, matches: matches + + +def apply_simplification(simplification: Callable[[GraphT], GraphT]) -> Callable[ + [GraphT, list], pyzx.rules.RewriteOutputType[ET, VT]]: + def rule(g: GraphT, matches: list) -> pyzx.rules.RewriteOutputType[ET, VT]: + simplification(g) + return ({}, [], [], True) + + return rule + + +def _extract_circuit(graph: GraphT, matches: list) -> GraphT: + graph.auto_detect_io() + simplify.full_reduce(graph) + return extract_circuit(graph).to_graph() + + +simplifications: dict[str, RewriteData] = { + 'bialg_simp': { + "text": "bialgebra", + "tooltip": "bialg_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.bialg_simp), + "type": MATCHES_VERTICES, + }, + 'spider_simp': { + "text": "spider fusion", + "tooltip": "spider_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.spider_simp), + "type": MATCHES_VERTICES, + }, + 'id_simp': { + "text": "id", + "tooltip": "id_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.id_simp), + "type": MATCHES_VERTICES, + }, + 'phase_free_simp': { + "text": "phase free", + "tooltip": "phase_free_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.phase_free_simp), + "type": MATCHES_VERTICES, + }, + 'pivot_simp': { + "text": "pivot", + "tooltip": "pivot_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.pivot_simp), + "type": MATCHES_VERTICES, + }, + 'pivot_gadget_simp': { + "text": "pivot gadget", + "tooltip": "pivot_gadget_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.pivot_gadget_simp), + "type": MATCHES_VERTICES, + }, + 'pivot_boundary_simp': { + "text": "pivot boundary", + "tooltip": "pivot_boundary_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.pivot_boundary_simp), + "type": MATCHES_VERTICES, + }, + 'gadget_simp': { + "text": "gadget", + "tooltip": "gadget_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.gadget_simp), + "type": MATCHES_VERTICES, + }, + 'lcomp_simp': { + "text": "local complementation", + "tooltip": "lcomp_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.lcomp_simp), + "type": MATCHES_VERTICES, + }, + 'clifford_simp': { + "text": "clifford simplification", + "tooltip": "clifford_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.clifford_simp), + "type": MATCHES_VERTICES, + }, + 'tcount': { + "text": "tcount", + "tooltip": "tcount", + "matcher": const_true, + "rule": apply_simplification(simplify.tcount), + "type": MATCHES_VERTICES, + }, + 'to_gh': { + "text": "to green-hadamard form", + "tooltip": "to_gh", + "matcher": const_true, + "rule": apply_simplification(simplify.to_gh), + "type": MATCHES_VERTICES, + }, + 'to_rg': { + "text": "to red-green form", + "tooltip": "to_rg", + "matcher": const_true, + "rule": apply_simplification(simplify.to_rg), + "type": MATCHES_VERTICES, + }, + 'full_reduce': { + "text": "full reduce", + "tooltip": "full_reduce", + "matcher": const_true, + "rule": apply_simplification(simplify.full_reduce), + "type": MATCHES_VERTICES, + }, + 'teleport_reduce': { + "text": "teleport reduce", + "tooltip": "teleport_reduce", + "matcher": const_true, + "rule": apply_simplification(simplify.teleport_reduce), + "type": MATCHES_VERTICES, + }, + 'reduce_scalar': { + "text": "reduce scalar", + "tooltip": "reduce_scalar", + "matcher": const_true, + "rule": apply_simplification(simplify.reduce_scalar), + "type": MATCHES_VERTICES, + }, + 'supplementarity_simp': { + "text": "supplementarity", + "tooltip": "supplementarity_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.supplementarity_simp), + "type": MATCHES_VERTICES, + }, + 'to_clifford_normal_form_graph': { + "text": "to clifford normal form", + "tooltip": "to_clifford_normal_form_graph", + "matcher": const_true, + "rule": apply_simplification(simplify.to_clifford_normal_form_graph), + "type": MATCHES_VERTICES, + }, + 'extract_circuit': { + "text": "circuit extraction", + "tooltip": "extract_circuit", + "matcher": const_true, + "rule": _extract_circuit, + "type": MATCHES_VERTICES, + "returns_new_graph": True, + }, +} + +rules_basic = {"spider", "to_z", "to_x", "rem_id", "copy", "pauli", "bialgebra", "euler"} + +rules_zxw = {"spider", "fuse_w", "z_to_z_box"} + +rules_zh = {"had2edge", "fuse_hbox", "mult_hbox"} + +action_groups = { + "Basic rules": {key: operations[key] for key in rules_basic}, + "Graph-like rules": rewrites_graph_theoretic, + "ZXW rules": {key: operations[key] for key in rules_zxw}, + "ZH rules": {key: operations[key] for key in rules_zh}, + "Simplification routines": simplifications, +} + + +def refresh_custom_rules() -> None: + action_groups["Custom rules"] = {rule["text"]: rule for rule in read_custom_rules()} + + +refresh_custom_rules()