Skip to content

Commit

Permalink
Full IR visualization and Pulse canonicalize (#376)
Browse files Browse the repository at this point in the history
* finish the field IR for uniform and runtimevec.

* update, add canonicalize for pulses

* enhance display

* comply lint

* add visualization for IRs

* merge main, and remove plotting from task_test

* fixing linter

* fixing linter

---------

Co-authored-by: Kai-Hsin Wu <[email protected]>
Co-authored-by: Phillip Weinberg <[email protected]>
Co-authored-by: Phillip Weinberg <[email protected]>
  • Loading branch information
4 people authored Aug 19, 2023
1 parent 971859c commit b13d824
Show file tree
Hide file tree
Showing 9 changed files with 605 additions and 11 deletions.
127 changes: 127 additions & 0 deletions src/bloqade/ir/control/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
from .waveform import Waveform
from typing import Dict
from ..tree_print import Printer
from bokeh.plotting import figure, show
from bokeh.layouts import gridplot, row, layout
from bokeh.models.widgets import PreText
from bokeh.models import ColumnDataSource
from bloqade.visualization.ir_visualize import get_field_figure

__all__ = [
"Field",
Expand Down Expand Up @@ -49,6 +54,12 @@ def __repr__(self) -> str:
def _repr_pretty_(self, p, cycle):
Printer(p).print(self, cycle)

def _get_data(self, **assignment):
return {}

def figure(self, **assignment):
raise NotImplementedError


@dataclass
class UniformModulation(SpatialModulation):
Expand All @@ -64,6 +75,23 @@ def print_node(self):
def children(self):
return []

def _get_data(self, **assignment):
return ["uni"], ["all"]

def figure(self, **assignment):
p = figure(sizing_mode="stretch_both")
p.text(
x=[0.5],
y=[0.5],
text="Uniform",
text_algin="center",
text_baseline="middle",
)
return p

def show(self, **assignment):
show(self.figure(**assignment))


Uniform = UniformModulation()

Expand All @@ -84,6 +112,23 @@ def print_node(self):
def children(self):
return [self.name]

def figure(self, **assginment):
p = figure(sizing_mode="stretch_both")
p.text(
x=[0.5],
y=[0.5],
text=self.name,
text_algin="center",
text_baseline="middle",
)
return p

def _get_data(self, **assignment):
return [self.name], ["vec"]

def show(self, **assignment):
show(self.figure(**assignment))


@dataclass(init=False, repr=False)
class ScaledLocations(SpatialModulation):
Expand All @@ -110,6 +155,16 @@ def __str__(self):
tmp = {f"{key.value}": val for key, val in self.value.items()}
return f"ScaledLocations({str(tmp)})"

def _get_data(self, **assignments):
names = []
scls = []

for loc, scl in self.value.items():
names.append("loc[%d]" % (loc.value))
scls.append(str(scl(**assignments)))

return names, scls

def print_node(self):
return self.__str__()

Expand All @@ -122,6 +177,26 @@ def children(self):

return annotated_children

def figure(self, **assignments):
locs = []
literal_val = []
for k, v in self.value.items():
locs.append(f"loc[{k.value}]")
literal_val.append(float(v(**assignments)))

source = ColumnDataSource(data=dict(locations=locs, yvals=literal_val))

p = figure(
y_range=locs, sizing_mode="stretch_both", x_axis_label="Scale factor"
)
p.hbar(y="locations", right="yvals", source=source, height=0.4)

return p

def show(self, **assignment):
show(self.figure(**assignment))
pass


@dataclass
class Field:
Expand Down Expand Up @@ -172,3 +247,55 @@ def print_node(self):
def children(self):
# return dict with annotations
return {spatial_mod.print_node(): wf for spatial_mod, wf in self.value.items()}

def figure_old(self, **assignments):
full_figs = []
idx = 0
for spmod, wf in self.value.items():
fig_mod = spmod.figure(**assignments)
fig_wvfm = wf.figure(**assignments)

# format AST tree:
txt = wf.__repr__()
txt = "> Waveform AST:\n" + txt

txt_asgn = ""
# format assignment:
if len(assignments):
txt_asgn = "> Assignments:\n"
for key, val in assignments.items():
txt_asgn += f"{key} := {val}\n"
txt_asgn += "\n"

# Display AST tree:

header = "Ch[%d]\n" % (idx)
text_box = PreText(text=header + txt_asgn + txt, sizing_mode="stretch_both")
text_box.styles = {"overflow": "scroll", "border": "1px solid black"}

# layout channel:
fp = gridplot(
[[row(text_box, fig_mod, sizing_mode="stretch_both"), fig_wvfm]],
merge_tools=False,
sizing_mode="stretch_both",
)
fp.styles = {"border": "2px solid black"}
fp.width_policy = "max"

full_figs.append(fp)
idx += 1

full = layout(
full_figs,
# merge_tools=False,
sizing_mode="stretch_both",
)
full.width_policy = "max"

return full

def figure(self, **assignments):
return get_field_figure(self, "Field", None, **assignments)

def show(self, **assignments):
show(self.figure(**assignments))
39 changes: 38 additions & 1 deletion src/bloqade/ir/control/pulse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from typing import List
from pydantic.dataclasses import dataclass
from ..tree_print import Printer

from bokeh.io import show
from bloqade.visualization.ir_visualize import get_pulse_figure

__all__ = [
"Pulse",
Expand Down Expand Up @@ -102,6 +103,15 @@ def slice(self, interval: Interval) -> "PulseExpr":
@staticmethod
def canonicalize(expr: "PulseExpr") -> "PulseExpr":
# TODO: update canonicalization rules for appending pulses
match expr:
case Append([Append(lhs), Append(rhs)]):
return Append(list(map(PulseExpr.canonicalize, lhs + rhs)))
case Append([Append(pulses), pulse]):
return PulseExpr.canonicalize(Append(pulses + [pulse]))
case Append([pulse, Append(pulses)]):
return PulseExpr.canonicalize(Append([pulse] + pulses))
case _:
return expr
return expr

def __repr__(self) -> str:
Expand All @@ -112,6 +122,15 @@ def __repr__(self) -> str:
def _repr_pretty_(self, p, cycle):
Printer(p).print(self, cycle)

def _get_data(self, **assigments):
return NotImplementedError

def figure(self, **assignments):
return NotImplementedError

def show(self, **assignments):
return NotImplementedError


@dataclass
class Append(PulseExpr):
Expand Down Expand Up @@ -167,6 +186,15 @@ def children(self):
}
return annotated_children

def _get_data(self, **assigments):
return None, self.value

def figure(self, **assignments):
return get_pulse_figure(self, **assignments)

def show(self, **assignments):
show(self.figure(**assignments))


@dataclass
class NamedPulse(PulseExpr):
Expand All @@ -182,6 +210,15 @@ def print_node(self):
def children(self):
return {"Name": self.name, "Pulse": self.pulse}

def _get_data(self, **assigments):
return self.name, self.pulse.value

def figure(self, **assignments):
return get_pulse_figure(self, **assignments)

def show(self, **assignments):
show(self.figure(**assignments))


@dataclass
class Slice(PulseExpr):
Expand Down
29 changes: 29 additions & 0 deletions src/bloqade/ir/control/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from pydantic.dataclasses import dataclass
from typing import List, Dict
from bloqade.visualization.ir_visualize import get_sequence_figure
from bokeh.io import show


__all__ = [
Expand Down Expand Up @@ -76,6 +78,15 @@ def __repr__(self) -> str:
def _repr_pretty_(self, p, cycle):
Printer(p).print(self, cycle)

def _get_data(self, **assignment):
raise NotImplementedError

def figure(self, **assignment):
raise NotImplementedError

def show(self, **assignment):
raise NotImplementedError


@dataclass
class Append(SequenceExpr):
Expand Down Expand Up @@ -131,6 +142,15 @@ def children(self):
def print_node(self):
return "Sequence"

def _get_data(self, **assignments):
return None, self.value

def figure(self, **assignments):
return get_sequence_figure(self, **assignments)

def show(self, **assignments):
show(self.figure(**assignments))


@dataclass
class NamedSequence(SequenceExpr):
Expand All @@ -146,6 +166,15 @@ def children(self):
def print_node(self):
return "NamedSequence"

def _get_data(self, **assignment):
return self.name, self.sequence.value

def figure(self, **assignments):
return get_sequence_figure(self, **assignments)

def show(self, **assignments):
show(self.figure(**assignments))


@dataclass(repr=False)
class Slice(SequenceExpr):
Expand Down
29 changes: 24 additions & 5 deletions src/bloqade/ir/control/waveform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from ..tree_print import Printer
from ..scalar import Scalar, Interval, Variable, cast
from bokeh.plotting import figure
from bokeh.plotting import figure, show
import numpy as np
import inspect
import scipy.integrate as integrate
Expand Down Expand Up @@ -85,19 +85,35 @@ def add(self, other: "Waveform") -> "Waveform":
def append(self, other: "Waveform") -> "Waveform":
return self.canonicalize(Append([self, other]))

def plot(self, **assignments):
"""Plot the waveform.
def figure(self, **assignments):
"""get figure of the plotting the waveform.
Returns:
figure: a bokeh figure
"""
duration = self.duration(**assignments)
# Varlist = []
duration = float(self.duration(**assignments))
times = np.linspace(0, duration, 1001)
values = [self.__call__(time, **assignments) for time in times]
fig = figure(width=250, height=250)
fig = figure(
sizing_mode="stretch_both",
x_axis_label="Time (s)",
y_axis_label="Waveform(t)",
tools="hover",
)
fig.line(times, values)

return fig

def _get_data(self, npoints, **assignments):
duration = float(self.duration(**assignments))
times = np.linspace(0, duration, npoints + 1)
values = [self.__call__(time, **assignments) for time in times]
return times, values

def show(self, **assignments):
show(self.figure(**assignments))

def align(
self, alignment: Alignment, value: Union[None, AlignedValue, Scalar] = None
) -> "Waveform":
Expand Down Expand Up @@ -186,6 +202,9 @@ def __repr__(self) -> str:
def _repr_pretty_(self, p, cycle):
Printer(p).print(self, cycle)

def print_node(self):
raise NotImplementedError


@dataclass(repr=False)
class AlignedWaveform(Waveform):
Expand Down
10 changes: 10 additions & 0 deletions src/bloqade/ir/program.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from bloqade.ir import Sequence
from typing import TYPE_CHECKING, Union
from bokeh.io import show
from bokeh.layouts import row

if TYPE_CHECKING:
from bloqade.ir.location.base import AtomArrangement, ParallelRegister
Expand Down Expand Up @@ -46,3 +48,11 @@ def sequence(self):
"""
return self._sequence

def figure(self, **assignments):
fig_reg = self._register.figure(**assignments)
fig_seq = self._sequence.figure(**assignments)
return row(fig_seq, fig_reg)

def show(self, **assignments):
show(self.figure(**assignments))
6 changes: 3 additions & 3 deletions src/bloqade/ir/tree_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ def __init__(self, p=None):
self.max_tree_depth = max_tree_depth

def should_print_annotation(self, children):
if (
type(children) == list or type(children) == tuple
if isinstance(
children, (list, tuple)
): # or generator, not sure of equivalence in Python
return False
elif type(children) == dict:
elif isinstance(children, dict):
return True

def get_value(self):
Expand Down
Loading

0 comments on commit b13d824

Please sign in to comment.