Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove short_name method #934

Merged
merged 11 commits into from
May 14, 2024
15 changes: 9 additions & 6 deletions qualtran/_infra/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from functools import cached_property
from typing import Dict, List, Set, Tuple, TYPE_CHECKING
from typing import cast, Dict, List, Optional, Set, Tuple, TYPE_CHECKING

import cirq
from attrs import frozen
Expand Down Expand Up @@ -170,10 +170,6 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
"""The call graph takes the adjoint of each of the bloqs in `subbloq`'s call graph."""
return {(bloq.adjoint(), n) for bloq, n in self.subbloq.build_call_graph(ssa=ssa)}

def short_name(self) -> str:
"""The subbloq's short_name with a dagger."""
return self.subbloq.short_name() + '†'

def pretty_name(self) -> str:
"""The subbloq's pretty_name with a dagger."""
return self.subbloq.pretty_name() + '†'
Expand All @@ -182,10 +178,17 @@ def __str__(self) -> str:
"""Delegate to subbloq's `__str__` method."""
return f'Adjoint(subbloq={str(self.subbloq)})'

def wire_symbol(self, reg: 'Register', idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
def wire_symbol(
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
) -> 'WireSymbol':
# Note: since we pass are passed a soquet which has the 'new' side, we flip it before
# delegating and then flip back. Subbloqs only have to answer this protocol
# if the provided soquet is facing the correct direction.
from qualtran.drawing import Text

if reg is None:
return Text(cast(Text, self.subbloq.wire_symbol(reg=None)).text + '†')

return self.subbloq.wire_symbol(reg=reg.adjoint(), idx=idx).adjoint()

def _t_complexity_(self):
Expand Down
8 changes: 4 additions & 4 deletions qualtran/_infra/adjoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import Dict, TYPE_CHECKING
from typing import cast, Dict, TYPE_CHECKING

import pytest
import sympy
Expand All @@ -25,7 +25,7 @@
from qualtran.bloqs.for_testing.with_call_graph import TestBloqWithCallGraph
from qualtran.bloqs.for_testing.with_decomposition import TestParallelCombo, TestSerialCombo
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.drawing import LarrowTextBox, RarrowTextBox
from qualtran.drawing import LarrowTextBox, RarrowTextBox, Text

if TYPE_CHECKING:
from qualtran import BloqBuilder, SoquetT
Expand Down Expand Up @@ -149,11 +149,11 @@ def test_call_graph():
def test_names():
atom = TestAtom()
assert atom.pretty_name() == "TestAtom"
assert atom.short_name() == "Atom"
assert cast(Text, atom.wire_symbol(reg=None)).text == "TestAtom"

adj_atom = Adjoint(atom)
assert adj_atom.pretty_name() == "TestAtom†"
assert adj_atom.short_name() == "Atom†"
assert cast(Text, adj_atom.wire_symbol(reg=None)).text == "TestAtom†"
assert str(adj_atom) == "Adjoint(subbloq=TestAtom())"


Expand Down
27 changes: 17 additions & 10 deletions qualtran/_infra/bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,6 @@ def signature(self) -> 'Signature':
def pretty_name(self) -> str:
return self.__class__.__name__

def short_name(self) -> str:
name = self.pretty_name()
if len(name) <= 10:
return name

return name[:8] + '..'

def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str, 'SoquetT']:
"""Override this method to define a Bloq in terms of its constituent parts.

Expand Down Expand Up @@ -282,7 +275,9 @@ def add_my_tensors(
from qualtran.simulation.tensor import cbloq_as_contracted_tensor

cbloq = self.decompose_bloq()
tn.add(cbloq_as_contracted_tensor(cbloq, incoming, outgoing, tags=[self.short_name(), tag]))
tn.add(
cbloq_as_contracted_tensor(cbloq, incoming, outgoing, tags=[self.pretty_name(), tag])
)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
"""Override this method to build the bloq call graph.
Expand Down Expand Up @@ -508,18 +503,30 @@ def on_registers(

return self.on(*merge_qubits(self.signature, **qubit_regs))

def wire_symbol(self, reg: 'Register', idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
def wire_symbol(
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
) -> 'WireSymbol':
"""On a musical score visualization, use this `WireSymbol` to represent `soq`.

By default, we use a "directional text box", which is a text box that is either
rectangular for thru-registers or facing to the left or right for non-thru-registers.

If reg is specified as `None`, this should return a Text label for the title of
the gate. If no title is needed (as the wire_symbols are self-explanatory),
this should return `Text('')`.

Override this method to provide a more relevant `WireSymbol` for the provided soquet.
This method can access bloq attributes. For example: you may want to draw either
a filled or empty circle for a control register depending on a control value bloq
attribute.
"""
from qualtran.drawing import directional_text_box
from qualtran.drawing import directional_text_box, Text

if reg is None:
name = self.pretty_name()
if len(name) <= 10:
return Text(name)
return Text(name[:8] + '..')

label = reg.name
if len(idx) > 0:
Expand Down
13 changes: 6 additions & 7 deletions qualtran/_infra/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def add_my_tensors(
subbloq_shape = tensor_shape_from_signature(self.subbloq.signature)
data[active_idx] = self.subbloq.tensor_contract().reshape(subbloq_shape)
# Add the data to the tensor network.
tn.add(qtn.Tensor(data=data, inds=out_ind + in_ind, tags=[self.short_name(), tag]))
tn.add(qtn.Tensor(data=data, inds=out_ind + in_ind, tags=[self.pretty_name(), tag]))

def _unitary_(self):
if isinstance(self.subbloq, GateWithRegisters):
Expand All @@ -433,11 +433,13 @@ def _unitary_(self):
# Unable to determine the unitary effect.
return NotImplemented

def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
from qualtran.drawing import Text

if reg is None:
return Text(f'C[{self.subbloq.wire_symbol(reg=None)}]')
if reg.name not in self.ctrl_reg_names:
# Delegate to subbloq
print(self.subbloq)
print(type(self.subbloq))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whoops

return self.subbloq.wire_symbol(reg, idx)

# Otherwise, it's part of the control register.
Expand All @@ -450,9 +452,6 @@ def adjoint(self) -> 'Bloq':
def pretty_name(self) -> str:
return f'C[{self.subbloq.pretty_name()}]'

def short_name(self) -> str:
return f'C[{self.subbloq.short_name()}]'

def __str__(self) -> str:
return f'C[{self.subbloq}]'

Expand Down
14 changes: 6 additions & 8 deletions qualtran/_infra/gate_with_registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,12 @@ def as_cirq_op(
)
return self.on_registers(**all_quregs), out_quregs

def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
from qualtran.cirq_interop._cirq_to_bloq import _wire_symbol_from_gate
from qualtran.drawing import Text

if reg is None:
return Text(self.pretty_name())

return _wire_symbol_from_gate(self, self.signature, reg, idx)

Expand Down Expand Up @@ -515,13 +519,7 @@ def add_my_tensors(
from qualtran.cirq_interop._cirq_to_bloq import _add_my_tensors_from_gate

_add_my_tensors_from_gate(
self,
self.signature,
self.short_name(),
tn,
tag,
incoming=incoming,
outgoing=outgoing,
self, self.signature, str(self), tn, tag, incoming=incoming, outgoing=outgoing
)
else:
return super().add_my_tensors(tn, tag, incoming=incoming, outgoing=outgoing)
Expand Down
15 changes: 7 additions & 8 deletions qualtran/bloqs/arithmetic/addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def add_my_tensors(
for a, b in itertools.product(range(N_a), range(N_b)):
unitary[a, b, a, int(math.fmod(a + b, N_b))] = 1

tn.add(qtn.Tensor(data=unitary, inds=inds, tags=[self.short_name(), tag]))
tn.add(qtn.Tensor(data=unitary, inds=inds, tags=[self.pretty_name(), tag]))

def decompose_bloq(self) -> 'CompositeBloq':
return decompose_from_cirq_style_method(self)
Expand All @@ -155,17 +155,16 @@ def on_classical_vals(
N = 2**b_bitsize if unsigned else 2 ** (b_bitsize - 1)
return {'a': a, 'b': int(math.fmod(a + b, N))}

def short_name(self) -> str:
return "a+b"

def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo:
wire_symbols = ["In(x)"] * int(self.a_dtype.bitsize)
wire_symbols += ["In(y)/Out(x+y)"] * int(self.b_dtype.bitsize)
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)

def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
from qualtran.drawing import directional_text_box
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
from qualtran.drawing import directional_text_box, Text

if reg is None:
return Text("a+b")
if reg.name == 'a':
return directional_text_box('a', side=reg.side)
elif reg.name == 'b':
Expand Down Expand Up @@ -318,7 +317,7 @@ def on_classical_vals(
def with_registers(self, *new_registers: Union[int, Sequence[int]]):
raise NotImplementedError("no need to implement with_registers.")

def short_name(self) -> str:
def pretty_name(self) -> str:
return "c = a + b"

def decompose_from_registers(
Expand Down Expand Up @@ -501,7 +500,7 @@ def build_composite_bloq(
else:
return {'x': x}

def short_name(self) -> str:
def pretty_name(self) -> str:
return f'x += {self.k}'


Expand Down
60 changes: 40 additions & 20 deletions qualtran/bloqs/arithmetic/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,18 @@
# limitations under the License.

from functools import cached_property
from typing import Dict, Iterable, Iterator, List, Sequence, Set, Tuple, TYPE_CHECKING, Union
from typing import (
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
Union,
)

import attrs
import cirq
Expand Down Expand Up @@ -42,7 +53,7 @@
from qualtran.cirq_interop.bit_tools import iter_bits
from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity
from qualtran.drawing import WireSymbol
from qualtran.drawing.musical_score import TextBox
from qualtran.drawing.musical_score import Text, TextBox

if TYPE_CHECKING:
from qualtran import BloqBuilder
Expand All @@ -62,8 +73,12 @@ class LessThanConstant(GateWithRegisters, cirq.ArithmeticGate): # type: ignore[
def signature(self) -> Signature:
return Signature.build_from_dtypes(x=QUInt(self.bitsize), target=QBit())

def short_name(self) -> str:
return f'x<{self.less_than_val}'
def wire_symbol(
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
) -> 'WireSymbol':
if reg is None:
return Text(f'x<{self.less_than_val}')
return super().wire_symbol(reg, idx)

def registers(self) -> Sequence[Union[int, Sequence[int]]]:
return [2] * self.bitsize, self.less_than_val, [2]
Expand Down Expand Up @@ -428,8 +443,12 @@ def apply(self, *register_vals: int) -> Union[int, int, Iterable[int]]:
x_val, y_val, target_val = register_vals
return x_val, y_val, target_val ^ (x_val <= y_val)

def short_name(self) -> str:
return 'x <= y'
def wire_symbol(
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
) -> 'WireSymbol':
if reg is None:
return Text('x <= y')
return super().wire_symbol(reg, idx)

def on_classical_vals(self, *, x: int, y: int, target: int) -> Dict[str, 'ClassicalValT']:
return {'x': x, 'y': y, 'target': target ^ (x <= y)}
Expand Down Expand Up @@ -599,16 +618,15 @@ def signature(self):
a=QUInt(self.a_bitsize), b=QUInt(self.b_bitsize), target=QBit()
)

def short_name(self) -> str:
return "a>b"

def _t_complexity_(self) -> 'TComplexity':
# TODO Determine precise clifford count and/or ignore.
# See: https://github.com/quantumlib/Qualtran/issues/219
# See: https://github.com/quantumlib/Qualtran/issues/217
return t_complexity(LessThanEqual(self.a_bitsize, self.b_bitsize))

def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> WireSymbol:
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> WireSymbol:
if reg is None:
return Text("a>b")
if reg.name == 'a':
return TextBox("In(a)")
if reg.name == 'b':
Expand Down Expand Up @@ -799,8 +817,12 @@ def build_composite_bloq(
# Return the output registers.
return {'a': a, 'b': b, 'target': target}

def short_name(self) -> str:
return "a > b"
def wire_symbol(
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
) -> 'WireSymbol':
if reg is None:
return Text('a > b')
return super().wire_symbol(reg, idx)


@frozen
Expand Down Expand Up @@ -836,10 +858,9 @@ def _t_complexity_(self) -> TComplexity:
# See: https://github.com/quantumlib/Qualtran/issues/217
return t_complexity(LessThanConstant(self.bitsize, less_than_val=self.val))

def short_name(self) -> str:
return f"x > {self.val}"

def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> WireSymbol:
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> WireSymbol:
if reg is None:
return Text(f"x > {self.val}")
if reg.name == 'x':
return TextBox("In(x)")
elif reg.name == 'target':
Expand Down Expand Up @@ -889,10 +910,9 @@ def signature(self) -> Signature:
def _t_complexity_(self) -> 'TComplexity':
return TComplexity(t=4 * (self.bitsize - 1))

def short_name(self) -> str:
return f"x == {self.val}"

def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> WireSymbol:
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> WireSymbol:
if reg is None:
return Text(f"x == {self.val}")
if reg.name == 'x':
return TextBox("In(x)")
elif reg.name == 'target':
Expand Down
Loading
Loading