Skip to content

Commit

Permalink
Support symbolic parameters in QROM bloq
Browse files Browse the repository at this point in the history
  • Loading branch information
tanujkhattar committed May 10, 2024
1 parent 8bb1199 commit 08794e5
Show file tree
Hide file tree
Showing 7 changed files with 464 additions and 85 deletions.
307 changes: 293 additions & 14 deletions qualtran/bloqs/data_loading/qrom.ipynb

Large diffs are not rendered by default.

138 changes: 102 additions & 36 deletions qualtran/bloqs/data_loading/qrom.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,27 @@
# limitations under the License.

"""Quantum read-only memory."""

from functools import cached_property
from typing import Callable, Dict, Iterable, Iterator, Sequence, Set, Tuple
from typing import Callable, Dict, Iterable, Iterator, Sequence, Set, Tuple, TYPE_CHECKING, Union

import attrs
import cirq
import numpy as np
import sympy
from numpy.typing import ArrayLike, NDArray

from qualtran import bloq_example, BloqDocSpec, BoundedQUInt, QAny, Register
from qualtran._infra.gate_with_registers import merge_qubits, total_bits
from qualtran._infra.gate_with_registers import merge_qubits
from qualtran.bloqs.basic_gates import CNOT
from qualtran.bloqs.mcmt.and_bloq import And, MultiAnd
from qualtran.bloqs.multiplexers.unary_iteration_bloq import UnaryIterationGate
from qualtran.drawing import Circle, TextBox, WireSymbol
from qualtran.resource_counting import BloqCountT
from qualtran.simulation.classical_sim import ClassicalValT
from qualtran.symbolics import bit_length, is_symbolic, prod, shape, Shaped, SymbolicInt

if TYPE_CHECKING:
from qualtran.resource_counting import SympySymbolAllocator


def _to_tuple(x: Iterable[NDArray]) -> Sequence[NDArray]:
Expand Down Expand Up @@ -70,41 +74,81 @@ class QROM(UnaryIterationGate):
Babbush et. al. (2020). Figure 3.
"""

data: Sequence[NDArray] = attrs.field(converter=_to_tuple)
selection_bitsizes: Tuple[int, ...] = attrs.field(
data_or_shape: Union[NDArray, Shaped] = attrs.field(
converter=lambda x: np.array(x) if isinstance(x, (list, tuple)) else x
)
selection_bitsizes: Tuple[SymbolicInt, ...] = attrs.field(
converter=lambda x: tuple(x.tolist() if isinstance(x, np.ndarray) else x)
)
target_bitsizes: Tuple[int, ...] = attrs.field(
target_bitsizes: Tuple[SymbolicInt, ...] = attrs.field(
converter=lambda x: tuple(x.tolist() if isinstance(x, np.ndarray) else x)
)
num_controls: int = 0
num_controls: SymbolicInt = 0

def has_data(self) -> bool:
return not isinstance(self.data_or_shape, Shaped)

@property
def data_shape(self) -> Tuple[SymbolicInt, ...]:
return shape(self.data_or_shape)[1:]

@property
def data(self) -> np.ndarray:
if not self.has_data():
raise ValueError(f"Data not available for symbolic QROM {self}")
assert isinstance(self.data_or_shape, np.ndarray)
return self.data_or_shape

def __attrs_post_init__(self):
assert all([is_symbolic(s) or isinstance(s, int) for s in self.selection_bitsizes])
assert all([is_symbolic(t) or isinstance(t, int) for t in self.target_bitsizes])
assert len(self.target_bitsizes) == self.data_or_shape.shape[0], (
f"len(self.target_bitsizes)={len(self.target_bitsizes)} should be same as "
f"len(self.data)={self.data_or_shape.shape[0]}"
)
if isinstance(self.data_or_shape, np.ndarray) and not is_symbolic(*self.target_bitsizes):
assert all(
t >= int(np.max(d)).bit_length() for t, d in zip(self.target_bitsizes, self.data)
)
assert isinstance(self.selection_bitsizes, tuple)
assert isinstance(self.target_bitsizes, tuple)

@classmethod
def build(cls, *data: ArrayLike, num_controls: int = 0) -> 'QROM':
_data = [np.array(d, dtype=int) for d in data]
selection_bitsizes = tuple((s - 1).bit_length() for s in _data[0].shape)
def build_from_data(cls, *data: ArrayLike, num_controls: SymbolicInt = 0) -> 'QROM':
_data = np.array([np.array(d, dtype=int) for d in data])
selection_bitsizes = tuple((s - 1).bit_length() for s in _data.shape[1:])
target_bitsizes = tuple(max(int(np.max(d)).bit_length(), 1) for d in data)
return QROM(
data=_data,
data_or_shape=_data,
selection_bitsizes=selection_bitsizes,
target_bitsizes=target_bitsizes,
num_controls=num_controls,
)

def __attrs_post_init__(self):
shapes = [d.shape for d in self.data]
assert all([isinstance(s, int) for s in self.selection_bitsizes])
assert all([isinstance(t, int) for t in self.target_bitsizes])
assert len(set(shapes)) == 1, f"Data must all have the same size: {shapes}"
assert len(self.target_bitsizes) == len(self.data), (
f"len(self.target_bitsizes)={len(self.target_bitsizes)} should be same as "
f"len(self.data)={len(self.data)}"
@classmethod
def build_from_bitsize(
cls,
data_len_or_shape: Union[SymbolicInt, Tuple[SymbolicInt, ...]],
target_bitsizes: Union[SymbolicInt, Tuple[SymbolicInt, ...]],
*,
selection_bitsizes: Tuple[SymbolicInt, ...] = (),
num_controls: SymbolicInt = 0,
) -> 'QROM':
data_shape = (
(data_len_or_shape,) if isinstance(data_len_or_shape, int) else data_len_or_shape
)
assert all(
t >= int(np.max(d)).bit_length() for t, d in zip(self.target_bitsizes, self.data)
if not isinstance(target_bitsizes, tuple):
target_bitsizes = (target_bitsizes,)
_data = Shaped((len(target_bitsizes),) + data_shape)
if selection_bitsizes is ():
selection_bitsizes = tuple(bit_length(s - 1) for s in _data.shape[1:])
assert len(selection_bitsizes) == len(_data.shape) - 1
return QROM(
data_or_shape=_data,
selection_bitsizes=selection_bitsizes,
target_bitsizes=target_bitsizes,
num_controls=num_controls,
)
assert isinstance(self.selection_bitsizes, tuple)
assert isinstance(self.target_bitsizes, tuple)

@cached_property
def control_registers(self) -> Tuple[Register, ...]:
Expand All @@ -114,8 +158,8 @@ def control_registers(self) -> Tuple[Register, ...]:
def selection_registers(self) -> Tuple[Register, ...]:
types = [
BoundedQUInt(sb, l)
for l, sb in zip(self.data[0].shape, self.selection_bitsizes)
if sb > 0
for l, sb in zip(self.data_or_shape.shape[1:], self.selection_bitsizes)
if is_symbolic(sb) or sb > 0
]
if len(types) == 1:
return (Register('selection', types[0]),)
Expand All @@ -124,7 +168,9 @@ def selection_registers(self) -> Tuple[Register, ...]:
@cached_property
def target_registers(self) -> Tuple[Register, ...]:
return tuple(
Register(f'target{i}_', QAny(l)) for i, l in enumerate(self.target_bitsizes) if l
Register(f'target{i}_', QAny(l))
for i, l in enumerate(self.target_bitsizes)
if is_symbolic(l) or l
)

def _load_nth_data(
Expand All @@ -144,7 +190,7 @@ def decompose_zero_selection(
) -> Iterator[cirq.OP_TREE]:
controls = merge_qubits(self.control_registers, **quregs)
target_regs = {reg.name: quregs[reg.name] for reg in self.target_registers}
zero_indx = (0,) * len(self.data[0].shape)
zero_indx = (0,) * len(self.data_shape)
if self.num_controls == 0:
yield self._load_nth_data(zero_indx, cirq.X, **target_regs)
elif self.num_controls == 1:
Expand All @@ -165,6 +211,9 @@ def decompose_zero_selection(
context.qubit_manager.qfree(list(junk.flatten()) + [and_target])

def _break_early(self, selection_index_prefix: Tuple[int, ...], l: int, r: int):
if not self.has_data():
return False

for data in self.data:
data_l_r_flat = data[selection_index_prefix][l:r].flat
unique_element = data_l_r_flat[0]
Expand All @@ -181,6 +230,9 @@ def nth_operation(
yield self._load_nth_data(selection_idx, lambda q: CNOT().on(control, q), **target_regs)

def on_classical_vals(self, **vals: 'ClassicalValT') -> Dict[str, 'ClassicalValT']:
if not self.has_data():
raise NotImplementedError(f'Symbolic {self} does not support classical simulation')

if self.num_controls > 0:
control = vals['control']
if control != 2**self.num_controls - 1:
Expand All @@ -203,12 +255,10 @@ def on_classical_vals(self, **vals: 'ClassicalValT') -> Dict[str, 'ClassicalValT
targets = {k: v ^ vals[k] for k, v in targets.items()}
return controls | selections | targets

def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo:
wire_symbols = ["@"] * self.num_controls
wire_symbols += ["In"] * total_bits(self.selection_registers)
for i, target in enumerate(self.target_registers):
wire_symbols += [f"QROM_{i}"] * target.total_bits()
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
def _circuit_diagram_info_(self, args) -> cirq.CircuitDiagramInfo:
from qualtran.cirq_interop._bloq_to_cirq import _wire_symbol_to_cirq_diagram_info

return _wire_symbol_to_cirq_diagram_info(self, args)

def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
name = reg.name
Expand All @@ -223,7 +273,7 @@ def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSym
trg_indx = int(name.replace('target', '').replace('_', ''))
# match the sel index
subscript = chr(ord('a') + trg_indx)
return TextBox(f'data_{subscript}')
return TextBox(f'QROM_{subscript}')
elif name == 'control':
return Circle()
raise ValueError(f'Unrecognized register name {name}')
Expand All @@ -234,13 +284,22 @@ def __pow__(self, power: int):
return NotImplemented # pragma: no cover

def _value_equality_values_(self):
data_tuple = tuple(tuple(d.flatten()) for d in self.data)
data_tuple = (
tuple(tuple(d.flatten()) for d in self.data) if self.has_data() else self.data_or_shape
)
return (self.selection_registers, self.target_registers, self.control_registers, data_tuple)

def nth_operation_callgraph(self, **kwargs: int) -> Set['BloqCountT']:
selection_idx = tuple(kwargs[reg.name] for reg in self.selection_registers)
return {(CNOT(), sum(int(d[selection_idx]).bit_count() for d in self.data))}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
if self.has_data():
return super().build_call_graph(ssa=ssa)
n_and = prod(*self.data_shape) - 2 + self.num_controls
n_cnot = prod(*self.target_bitsizes, *self.data_shape)
return {(And(), n_and), (And().adjoint(), n_and), (CNOT(), n_cnot)}


@bloq_example
def _qrom_small() -> QROM:
Expand All @@ -265,8 +324,15 @@ def _qrom_multi_dim() -> QROM:
return qrom_multi_dim


@bloq_example
def _qrom_symb() -> QROM:
N, M, b1, b2, c = sympy.symbols('N M b1 b2 c')
qrom_symb = QROM.build_from_bitsize((N, M), (b1, b2), num_controls=c)
return qrom_symb


_QROM_DOC = BloqDocSpec(
bloq_cls=QROM,
import_line='from qualtran.bloqs.data_loading.qrom import QROM',
examples=[_qrom_small, _qrom_multi_data, _qrom_multi_dim],
examples=[_qrom_small, _qrom_multi_data, _qrom_multi_dim, _qrom_symb],
)
Loading

0 comments on commit 08794e5

Please sign in to comment.