From 08794e5c986b2115e401a6528059f22a44c78718 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Fri, 10 May 2024 16:48:53 -0700 Subject: [PATCH] Support symbolic parameters in QROM bloq --- qualtran/bloqs/data_loading/qrom.ipynb | 307 +++++++++++++++++- qualtran/bloqs/data_loading/qrom.py | 138 ++++++-- qualtran/bloqs/data_loading/qrom_test.py | 65 ++-- .../state_preparation_alias_sampling_test.py | 10 +- .../resource_counting/classify_bloqs_test.py | 4 +- qualtran/symbolics/__init__.py | 2 + qualtran/symbolics/math_funcs.py | 23 +- 7 files changed, 464 insertions(+), 85 deletions(-) diff --git a/qualtran/bloqs/data_loading/qrom.ipynb b/qualtran/bloqs/data_loading/qrom.ipynb index 06da48b7a..053285d59 100644 --- a/qualtran/bloqs/data_loading/qrom.ipynb +++ b/qualtran/bloqs/data_loading/qrom.ipynb @@ -14,12 +14,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "2737c79f", "metadata": { "cq.autogen": "top_imports" }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/tanujkhattar/opt/anaconda3/envs/qualtran2/lib/python3.11/site-packages/cotengra/hyperoptimizers/hyper.py:34: UserWarning: Couldn't import `kahypar` - skipping from default hyper optimizer and using basic `labels` method instead.\n", + " warnings.warn(\n" + ] + } + ], "source": [ "from qualtran import Bloq, CompositeBloq, BloqBuilder, Signature, Register\n", "from qualtran import QBit, QInt, QUInt, QAny\n", @@ -59,12 +68,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "af317b30", "metadata": { "cq.autogen": "QROM.bloq_doc.py" }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/tanujkhattar/quantum/Qualtran/qualtran/bloqs/data_loading/qrom.py:143: SyntaxWarning: \"is\" with a literal. Did you mean \"==\"?\n", + " if selection_bitsizes is ():\n" + ] + } + ], "source": [ "from qualtran.bloqs.data_loading.qrom import QROM" ] @@ -81,7 +99,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "5f02d641", "metadata": { "cq.autogen": "QROM.qrom_small" @@ -94,7 +112,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "c2f8b350", "metadata": { "cq.autogen": "QROM.qrom_multi_data" @@ -108,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "036cf220", "metadata": { "cq.autogen": "QROM.qrom_multi_dim" @@ -120,6 +138,29 @@ "qrom_multi_dim = QROM([data1, data2], selection_bitsizes=(2, 2), target_bitsizes=(8, 8))" ] }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a084ca8c-0c89-4439-86d9-51cf91e972c4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "QROM(data_or_shape=Shaped(shape=(2, N, M)), selection_bitsizes=(ceiling(log2(N - 1)), ceiling(log2(M - 1))), target_bitsizes=(b1, b2), num_controls=c)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "N, M, b1, b2, c = sympy.symbols('N M b1 b2 c')\n", + "qrom_symb = QROM.build_from_bitsize((N, M), (b1, b2), num_controls=c)\n", + "qrom_symb" + ] + }, { "cell_type": "markdown", "id": "b92d1c8e", @@ -132,16 +173,31 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "9681cfed", "metadata": { "cq.autogen": "QROM.graphical_signature.py" }, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ddfa1ff6362e457bb5450ef591c98e62", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(Output(outputs=({'output_type': 'display_data', 'data': {'text/plain': '\n", + "\n", + "counts\n", + "\n", + "\n", + "\n", + "b0\n", + "\n", + "QROM\n", + "data_or_shape=array([[ ..., selection_bitsizes=(3,), target_bitsizes=(3,), num_controls=0\n", + "\n", + "\n", + "\n", + "b1\n", + "\n", + "And\n", + "cv1=1, cv2=0, uncompute=False\n", + "\n", + "\n", + "\n", + "b0->b1\n", + "\n", + "\n", + "3\n", + "\n", + "\n", + "\n", + "b2\n", + "\n", + "XGate\n", + "\n", + "\n", + "\n", + "b0->b2\n", + "\n", + "\n", + "2\n", + "\n", + "\n", + "\n", + "b3\n", + "\n", + "CNOT\n", + "\n", + "\n", + "\n", + "b0->b3\n", + "\n", + "\n", + "8\n", + "\n", + "\n", + "\n", + "b4\n", + "\n", + "And†\n", + "cv1=1, cv2=1, uncompute=True\n", + "\n", + "\n", + "\n", + "b0->b4\n", + "\n", + "\n", + "3\n", + "\n", + "\n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "#### Counts totals:\n", + " - `And(cv1=1, cv2=0, uncompute=False)`: 3\n", + " - `And(cv1=1, cv2=1, uncompute=True)`: 3\n", + " - `CNOT()`: 8\n", + " - `XGate()`: 2" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "from qualtran.resource_counting.generalizers import ignore_split_join\n", "qrom_small_g, qrom_small_sigma = qrom_small.call_graph(max_depth=1, generalizer=ignore_split_join)\n", "show_call_graph(qrom_small_g)\n", "show_counts_sigma(qrom_small_sigma)" ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "5392752e-276e-434c-9d60-fadcf6478077", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "counts\n", + "\n", + "\n", + "\n", + "b0\n", + "\n", + "QROM\n", + "data_or_shape=Shaped(s ..., selection_bitsizes=(ceiling ..., target_bitsizes=(b1, b2), num_controls=c\n", + "\n", + "\n", + "\n", + "b1\n", + "\n", + "And\n", + "cv1=1, cv2=1, uncompute=False\n", + "\n", + "\n", + "\n", + "b0->b1\n", + "\n", + "\n", + "M⋅N + c - 2\n", + "\n", + "\n", + "\n", + "b2\n", + "\n", + "CNOT\n", + "\n", + "\n", + "\n", + "b0->b2\n", + "\n", + "\n", + "M⋅N⋅b₁⋅b₂\n", + "\n", + "\n", + "\n", + "b3\n", + "\n", + "And†\n", + "cv1=1, cv2=1, uncompute=True\n", + "\n", + "\n", + "\n", + "b0->b3\n", + "\n", + "\n", + "M⋅N + c - 2\n", + "\n", + "\n", + "\n", + "b4\n", + "\n", + "T\n", + "is_adjoint=False\n", + "\n", + "\n", + "\n", + "b1->b4\n", + "\n", + "\n", + "4\n", + "\n", + "\n", + "\n", + "b5\n", + "\n", + "ArbitraryClifford\n", + "n=2\n", + "\n", + "\n", + "\n", + "b1->b5\n", + "\n", + "\n", + "9\n", + "\n", + "\n", + "\n", + "b3->b5\n", + "\n", + "\n", + "4\n", + "\n", + "\n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "#### Counts totals:\n", + " - `ArbitraryClifford(n=2)`: $\\displaystyle 13 M N + 13 c - 26$\n", + " - `CNOT()`: $\\displaystyle M N b_{1} b_{2}$\n", + " - `TGate()`: $\\displaystyle 4 M N + 4 c - 8$" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "qrom_symb_g, qrom_symb_sigma = qrom_symb.call_graph(generalizer=ignore_split_join)\n", + "show_call_graph(qrom_symb_g)\n", + "show_counts_sigma(qrom_symb_sigma)" + ] } ], "metadata": { @@ -186,7 +465,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.11.8" } }, "nbformat": 4, diff --git a/qualtran/bloqs/data_loading/qrom.py b/qualtran/bloqs/data_loading/qrom.py index fa45df25a..82d2523c1 100644 --- a/qualtran/bloqs/data_loading/qrom.py +++ b/qualtran/bloqs/data_loading/qrom.py @@ -13,23 +13,27 @@ # limitations under the License. """Quantum read-only memory.""" - from functools import cached_property -from typing import Callable, Dict, Iterable, Iterator, Sequence, Set, Tuple +from typing import Callable, Dict, Iterable, Iterator, Sequence, Set, Tuple, TYPE_CHECKING, Union import attrs import cirq import numpy as np +import sympy from numpy.typing import ArrayLike, NDArray from qualtran import bloq_example, BloqDocSpec, BoundedQUInt, QAny, Register -from qualtran._infra.gate_with_registers import merge_qubits, total_bits +from qualtran._infra.gate_with_registers import merge_qubits from qualtran.bloqs.basic_gates import CNOT from qualtran.bloqs.mcmt.and_bloq import And, MultiAnd from qualtran.bloqs.multiplexers.unary_iteration_bloq import UnaryIterationGate from qualtran.drawing import Circle, TextBox, WireSymbol from qualtran.resource_counting import BloqCountT from qualtran.simulation.classical_sim import ClassicalValT +from qualtran.symbolics import bit_length, is_symbolic, prod, shape, Shaped, SymbolicInt + +if TYPE_CHECKING: + from qualtran.resource_counting import SympySymbolAllocator def _to_tuple(x: Iterable[NDArray]) -> Sequence[NDArray]: @@ -70,41 +74,81 @@ class QROM(UnaryIterationGate): Babbush et. al. (2020). Figure 3. """ - data: Sequence[NDArray] = attrs.field(converter=_to_tuple) - selection_bitsizes: Tuple[int, ...] = attrs.field( + data_or_shape: Union[NDArray, Shaped] = attrs.field( + converter=lambda x: np.array(x) if isinstance(x, (list, tuple)) else x + ) + selection_bitsizes: Tuple[SymbolicInt, ...] = attrs.field( converter=lambda x: tuple(x.tolist() if isinstance(x, np.ndarray) else x) ) - target_bitsizes: Tuple[int, ...] = attrs.field( + target_bitsizes: Tuple[SymbolicInt, ...] = attrs.field( converter=lambda x: tuple(x.tolist() if isinstance(x, np.ndarray) else x) ) - num_controls: int = 0 + num_controls: SymbolicInt = 0 + + def has_data(self) -> bool: + return not isinstance(self.data_or_shape, Shaped) + + @property + def data_shape(self) -> Tuple[SymbolicInt, ...]: + return shape(self.data_or_shape)[1:] + + @property + def data(self) -> np.ndarray: + if not self.has_data(): + raise ValueError(f"Data not available for symbolic QROM {self}") + assert isinstance(self.data_or_shape, np.ndarray) + return self.data_or_shape + + def __attrs_post_init__(self): + assert all([is_symbolic(s) or isinstance(s, int) for s in self.selection_bitsizes]) + assert all([is_symbolic(t) or isinstance(t, int) for t in self.target_bitsizes]) + assert len(self.target_bitsizes) == self.data_or_shape.shape[0], ( + f"len(self.target_bitsizes)={len(self.target_bitsizes)} should be same as " + f"len(self.data)={self.data_or_shape.shape[0]}" + ) + if isinstance(self.data_or_shape, np.ndarray) and not is_symbolic(*self.target_bitsizes): + assert all( + t >= int(np.max(d)).bit_length() for t, d in zip(self.target_bitsizes, self.data) + ) + assert isinstance(self.selection_bitsizes, tuple) + assert isinstance(self.target_bitsizes, tuple) @classmethod - def build(cls, *data: ArrayLike, num_controls: int = 0) -> 'QROM': - _data = [np.array(d, dtype=int) for d in data] - selection_bitsizes = tuple((s - 1).bit_length() for s in _data[0].shape) + def build_from_data(cls, *data: ArrayLike, num_controls: SymbolicInt = 0) -> 'QROM': + _data = np.array([np.array(d, dtype=int) for d in data]) + selection_bitsizes = tuple((s - 1).bit_length() for s in _data.shape[1:]) target_bitsizes = tuple(max(int(np.max(d)).bit_length(), 1) for d in data) return QROM( - data=_data, + data_or_shape=_data, selection_bitsizes=selection_bitsizes, target_bitsizes=target_bitsizes, num_controls=num_controls, ) - def __attrs_post_init__(self): - shapes = [d.shape for d in self.data] - assert all([isinstance(s, int) for s in self.selection_bitsizes]) - assert all([isinstance(t, int) for t in self.target_bitsizes]) - assert len(set(shapes)) == 1, f"Data must all have the same size: {shapes}" - assert len(self.target_bitsizes) == len(self.data), ( - f"len(self.target_bitsizes)={len(self.target_bitsizes)} should be same as " - f"len(self.data)={len(self.data)}" + @classmethod + def build_from_bitsize( + cls, + data_len_or_shape: Union[SymbolicInt, Tuple[SymbolicInt, ...]], + target_bitsizes: Union[SymbolicInt, Tuple[SymbolicInt, ...]], + *, + selection_bitsizes: Tuple[SymbolicInt, ...] = (), + num_controls: SymbolicInt = 0, + ) -> 'QROM': + data_shape = ( + (data_len_or_shape,) if isinstance(data_len_or_shape, int) else data_len_or_shape ) - assert all( - t >= int(np.max(d)).bit_length() for t, d in zip(self.target_bitsizes, self.data) + if not isinstance(target_bitsizes, tuple): + target_bitsizes = (target_bitsizes,) + _data = Shaped((len(target_bitsizes),) + data_shape) + if selection_bitsizes is (): + selection_bitsizes = tuple(bit_length(s - 1) for s in _data.shape[1:]) + assert len(selection_bitsizes) == len(_data.shape) - 1 + return QROM( + data_or_shape=_data, + selection_bitsizes=selection_bitsizes, + target_bitsizes=target_bitsizes, + num_controls=num_controls, ) - assert isinstance(self.selection_bitsizes, tuple) - assert isinstance(self.target_bitsizes, tuple) @cached_property def control_registers(self) -> Tuple[Register, ...]: @@ -114,8 +158,8 @@ def control_registers(self) -> Tuple[Register, ...]: def selection_registers(self) -> Tuple[Register, ...]: types = [ BoundedQUInt(sb, l) - for l, sb in zip(self.data[0].shape, self.selection_bitsizes) - if sb > 0 + for l, sb in zip(self.data_or_shape.shape[1:], self.selection_bitsizes) + if is_symbolic(sb) or sb > 0 ] if len(types) == 1: return (Register('selection', types[0]),) @@ -124,7 +168,9 @@ def selection_registers(self) -> Tuple[Register, ...]: @cached_property def target_registers(self) -> Tuple[Register, ...]: return tuple( - Register(f'target{i}_', QAny(l)) for i, l in enumerate(self.target_bitsizes) if l + Register(f'target{i}_', QAny(l)) + for i, l in enumerate(self.target_bitsizes) + if is_symbolic(l) or l ) def _load_nth_data( @@ -144,7 +190,7 @@ def decompose_zero_selection( ) -> Iterator[cirq.OP_TREE]: controls = merge_qubits(self.control_registers, **quregs) target_regs = {reg.name: quregs[reg.name] for reg in self.target_registers} - zero_indx = (0,) * len(self.data[0].shape) + zero_indx = (0,) * len(self.data_shape) if self.num_controls == 0: yield self._load_nth_data(zero_indx, cirq.X, **target_regs) elif self.num_controls == 1: @@ -165,6 +211,9 @@ def decompose_zero_selection( context.qubit_manager.qfree(list(junk.flatten()) + [and_target]) def _break_early(self, selection_index_prefix: Tuple[int, ...], l: int, r: int): + if not self.has_data(): + return False + for data in self.data: data_l_r_flat = data[selection_index_prefix][l:r].flat unique_element = data_l_r_flat[0] @@ -181,6 +230,9 @@ def nth_operation( yield self._load_nth_data(selection_idx, lambda q: CNOT().on(control, q), **target_regs) def on_classical_vals(self, **vals: 'ClassicalValT') -> Dict[str, 'ClassicalValT']: + if not self.has_data(): + raise NotImplementedError(f'Symbolic {self} does not support classical simulation') + if self.num_controls > 0: control = vals['control'] if control != 2**self.num_controls - 1: @@ -203,12 +255,10 @@ def on_classical_vals(self, **vals: 'ClassicalValT') -> Dict[str, 'ClassicalValT targets = {k: v ^ vals[k] for k, v in targets.items()} return controls | selections | targets - def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo: - wire_symbols = ["@"] * self.num_controls - wire_symbols += ["In"] * total_bits(self.selection_registers) - for i, target in enumerate(self.target_registers): - wire_symbols += [f"QROM_{i}"] * target.total_bits() - return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) + def _circuit_diagram_info_(self, args) -> cirq.CircuitDiagramInfo: + from qualtran.cirq_interop._bloq_to_cirq import _wire_symbol_to_cirq_diagram_info + + return _wire_symbol_to_cirq_diagram_info(self, args) def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol': name = reg.name @@ -223,7 +273,7 @@ def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSym trg_indx = int(name.replace('target', '').replace('_', '')) # match the sel index subscript = chr(ord('a') + trg_indx) - return TextBox(f'data_{subscript}') + return TextBox(f'QROM_{subscript}') elif name == 'control': return Circle() raise ValueError(f'Unrecognized register name {name}') @@ -234,13 +284,22 @@ def __pow__(self, power: int): return NotImplemented # pragma: no cover def _value_equality_values_(self): - data_tuple = tuple(tuple(d.flatten()) for d in self.data) + data_tuple = ( + tuple(tuple(d.flatten()) for d in self.data) if self.has_data() else self.data_or_shape + ) return (self.selection_registers, self.target_registers, self.control_registers, data_tuple) def nth_operation_callgraph(self, **kwargs: int) -> Set['BloqCountT']: selection_idx = tuple(kwargs[reg.name] for reg in self.selection_registers) return {(CNOT(), sum(int(d[selection_idx]).bit_count() for d in self.data))} + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: + if self.has_data(): + return super().build_call_graph(ssa=ssa) + n_and = prod(*self.data_shape) - 2 + self.num_controls + n_cnot = prod(*self.target_bitsizes, *self.data_shape) + return {(And(), n_and), (And().adjoint(), n_and), (CNOT(), n_cnot)} + @bloq_example def _qrom_small() -> QROM: @@ -265,8 +324,15 @@ def _qrom_multi_dim() -> QROM: return qrom_multi_dim +@bloq_example +def _qrom_symb() -> QROM: + N, M, b1, b2, c = sympy.symbols('N M b1 b2 c') + qrom_symb = QROM.build_from_bitsize((N, M), (b1, b2), num_controls=c) + return qrom_symb + + _QROM_DOC = BloqDocSpec( bloq_cls=QROM, import_line='from qualtran.bloqs.data_loading.qrom import QROM', - examples=[_qrom_small, _qrom_multi_data, _qrom_multi_dim], + examples=[_qrom_small, _qrom_multi_data, _qrom_multi_dim, _qrom_symb], ) diff --git a/qualtran/bloqs/data_loading/qrom_test.py b/qualtran/bloqs/data_loading/qrom_test.py index 792e5fee9..76fbe4908 100644 --- a/qualtran/bloqs/data_loading/qrom_test.py +++ b/qualtran/bloqs/data_loading/qrom_test.py @@ -17,6 +17,7 @@ import cirq import numpy as np import pytest +import sympy from qualtran._infra.gate_with_registers import split_qubits, total_bits from qualtran.bloqs.basic_gates import CNOT, TGate @@ -56,7 +57,7 @@ def test_qrom_multi_dim(bloq_autotester): ], ) def test_qrom_1d_full(data, num_controls): - qrom = QROM.build(*data, num_controls=num_controls) + qrom = QROM.build_from_data(*data, num_controls=num_controls) assert_valid_bloq_decomposition(qrom) greedy_mm = cirq.GreedyQubitManager('a', maximize_reuse=True) @@ -177,7 +178,7 @@ def test_qrom_3d_classical(): def test_qrom_diagram(): d0 = np.array([1, 2, 3]) d1 = np.array([4, 5, 6]) - qrom = QROM.build(d0, d1) + qrom = QROM.build_from_data(d0, d1) q = cirq.LineQubit.range(cirq.num_qubits(qrom)) circuit = cirq.Circuit(qrom.on_registers(**split_qubits(qrom.signature, q))) cirq.testing.assert_has_diagram( @@ -187,15 +188,15 @@ def test_qrom_diagram(): │ 1: ───In─────── │ -2: ───QROM_0─── +2: ───QROM_a─── │ -3: ───QROM_0─── +3: ───QROM_a─── │ -4: ───QROM_1─── +4: ───QROM_b─── │ -5: ───QROM_1─── +5: ───QROM_b─── │ -6: ───QROM_1───""", +6: ───QROM_b───""", ) @@ -208,11 +209,21 @@ def test_notebook(): "data", [[[1, 2, 3, 4, 5]], [[1, 2, 3], [4, 5, 10]], [[1], [2], [3], [4], [5], [6]]] ) def test_t_complexity(data): - qrom = QROM.build(*data) + qrom = QROM.build_from_data(*data) n = np.prod(qrom.data[0].shape) assert t_complexity(qrom).t == max(0, 4 * n - 8), n +def test_t_complexity_symbolic(): + N, M = sympy.symbols('N M') + b1, b2 = sympy.symbols('b1 b2') + c = sympy.Symbol('c') + qrom_symb = QROM.build_from_bitsize((N, M), (b1, b2), num_controls=c) + t_counts = qrom_symb.t_complexity() + assert t_counts.t == 4 * (N * M - 2 + c) + assert t_counts + + def _assert_qrom_has_diagram(qrom: QROM, expected_diagram: str): gh = GateHelper(qrom) op = gh.operation @@ -231,19 +242,19 @@ def _assert_qrom_has_diagram(qrom: QROM, expected_diagram: str): def test_qrom_variable_spacing(): # Tests for variable spacing optimization applied from https://arxiv.org/abs/2007.07391 data = [1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8] # Figure 3a. - assert t_complexity(QROM.build(data)).t == (8 - 2) * 4 + assert t_complexity(QROM.build_from_data(data)).t == (8 - 2) * 4 data = [1, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5] # Figure 3b. - assert t_complexity(QROM.build(data)).t == (5 - 2) * 4 + assert t_complexity(QROM.build_from_data(data)).t == (5 - 2) * 4 data = [1, 2, 3, 4, 4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7] # Negative test: t count is not (g-2)*4 - assert t_complexity(QROM.build(data)).t == (8 - 2) * 4 + assert t_complexity(QROM.build_from_data(data)).t == (8 - 2) * 4 # Works as expected when multiple data arrays are to be loaded. data = [1, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5] # (a) Both data sequences are identical - assert t_complexity(QROM.build(data, data)).t == (5 - 2) * 4 + assert t_complexity(QROM.build_from_data(data, data)).t == (5 - 2) * 4 # (b) Both data sequences have identical structure, even though the elements are not same. - assert t_complexity(QROM.build(data, 2 * np.array(data))).t == (5 - 2) * 4 + assert t_complexity(QROM.build_from_data(data, 2 * np.array(data))).t == (5 - 2) * 4 # Works as expected when multidimensional input data is to be loaded - qrom = QROM.build( + qrom = QROM.build_from_data( np.array( [ [1, 1, 1, 1, 1, 1, 1, 1], @@ -265,7 +276,7 @@ def test_qrom_variable_spacing(): ''', ) # When inner loop range is not a power of 2, the inner segment tree cannot be skipped. - qrom = QROM.build( + qrom = QROM.build_from_data( np.array( [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [2, 2, 2, 2, 2, 2]], dtype=int, @@ -286,21 +297,21 @@ def test_qrom_variable_spacing(): ''', ) # No T-gates needed if all elements to load are identical. - assert t_complexity(QROM.build([3, 3, 3, 3])).t == 0 + assert t_complexity(QROM.build_from_data([3, 3, 3, 3])).t == 0 def test_qrom_wire_symbols(): - qrom = QROM.build([3, 3, 3, 3]) - assert_wire_symbols_match_expected(qrom, ['In', 'data_a']) + qrom = QROM.build_from_data([3, 3, 3, 3]) + assert_wire_symbols_match_expected(qrom, ['In', 'QROM_a']) - qrom = QROM.build([3, 3, 3, 3], [2, 2, 2, 2]) - assert_wire_symbols_match_expected(qrom, ['In', 'data_a', 'data_b']) + qrom = QROM.build_from_data([3, 3, 3, 3], [2, 2, 2, 2]) + assert_wire_symbols_match_expected(qrom, ['In', 'QROM_a', 'QROM_b']) - qrom = QROM.build([[3, 3], [3, 3]], [[2, 2], [2, 2]], [[1, 1], [2, 2]]) - assert_wire_symbols_match_expected(qrom, ['In_i', 'In_j', 'data_a', 'data_b', 'data_c']) + qrom = QROM.build_from_data([[3, 3], [3, 3]], [[2, 2], [2, 2]], [[1, 1], [2, 2]]) + assert_wire_symbols_match_expected(qrom, ['In_i', 'In_j', 'QROM_a', 'QROM_b', 'QROM_c']) - qrom = QROM.build(np.arange(27).reshape(3, 3, 3)) - assert_wire_symbols_match_expected(qrom, ['In_i', 'In_j', 'In_k', 'data_a']) + qrom = QROM.build_from_data(np.arange(27).reshape(3, 3, 3)) + assert_wire_symbols_match_expected(qrom, ['In_i', 'In_j', 'In_k', 'QROM_a']) @pytest.mark.slow @@ -387,7 +398,7 @@ def test_ndim_t_complexity(data, num_controls): def test_qrom_call_graph_matches_decomposition(num_controls): # Base case arr = np.arange(50) - qrom = QROM.build(arr, num_controls=num_controls) + qrom = QROM.build_from_data(arr, num_controls=num_controls) _, sigma_call = qrom.call_graph(generalizer=cirq_to_bloqs) _, sigma_dcmp = qrom.decompose_bloq().call_graph(generalizer=cirq_to_bloqs) assert sigma_call[TGate()] == sigma_dcmp[TGate()] @@ -396,7 +407,7 @@ def test_qrom_call_graph_matches_decomposition(num_controls): # Multiple Multi dimensional arrays arr_a = np.arange(64).reshape(8, 8) arr_b = 10 * np.arange(64).reshape(8, 8) - qrom = QROM.build(arr_a, arr_b, num_controls=num_controls) + qrom = QROM.build_from_data(arr_a, arr_b, num_controls=num_controls) _, sigma_call = qrom.call_graph(generalizer=cirq_to_bloqs) _, sigma_dcmp = qrom.decompose_bloq().call_graph(generalizer=cirq_to_bloqs) assert sigma_call[TGate()] == sigma_dcmp[TGate()] @@ -405,7 +416,7 @@ def test_qrom_call_graph_matches_decomposition(num_controls): # Variable QROM case. arr_a = np.array([1, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5]) arr_b = 10 * arr_a - qrom = QROM.build(arr_a, arr_b, num_controls=num_controls) + qrom = QROM.build_from_data(arr_a, arr_b, num_controls=num_controls) _, sigma_call = qrom.call_graph(generalizer=cirq_to_bloqs) _, sigma_dcmp = qrom.decompose_bloq().call_graph(generalizer=cirq_to_bloqs) assert sigma_call[TGate()] == sigma_dcmp[TGate()] diff --git a/qualtran/bloqs/state_preparation/state_preparation_alias_sampling_test.py b/qualtran/bloqs/state_preparation/state_preparation_alias_sampling_test.py index e41353b6a..015891a72 100644 --- a/qualtran/bloqs/state_preparation/state_preparation_alias_sampling_test.py +++ b/qualtran/bloqs/state_preparation/state_preparation_alias_sampling_test.py @@ -99,15 +99,15 @@ def test_state_preparation_via_coherent_alias_sampling_diagram(): │ │ │ sigma_mu2: ─────────H────────────┼────────In(y)───────┼────── │ │ │ -alt0: ───────────────────────────QROM_0───┼───────────×(x)─── +alt0: ───────────────────────────QROM_a───┼───────────×(x)─── │ │ │ -alt1: ───────────────────────────QROM_0───┼───────────×(x)─── +alt1: ───────────────────────────QROM_a───┼───────────×(x)─── │ │ │ -keep0: ──────────────────────────QROM_1───In(x)───────┼────── +keep0: ──────────────────────────QROM_b───In(x)───────┼────── │ │ │ -keep1: ──────────────────────────QROM_1───In(x)───────┼────── +keep1: ──────────────────────────QROM_b───In(x)───────┼────── │ │ │ -keep2: ──────────────────────────QROM_1───In(x)───────┼────── +keep2: ──────────────────────────QROM_b───In(x)───────┼────── │ │ less_than_equal: ─────────────────────────⨁(x <= y)───@────── ''', diff --git a/qualtran/resource_counting/classify_bloqs_test.py b/qualtran/resource_counting/classify_bloqs_test.py index 187ab5404..abae3895d 100644 --- a/qualtran/resource_counting/classify_bloqs_test.py +++ b/qualtran/resource_counting/classify_bloqs_test.py @@ -57,7 +57,7 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: (((CSwap(10), 42),), 'swaps'), (((HammingWeightPhasing(10, 1.11), 11),), 'rotations'), (((Add(QInt(8)), 4),), 'arithmetic'), - (((QROM.build([4, 10, 11, 34]), 8),), 'data_loading'), + (((QROM.build_from_data([4, 10, 11, 34]), 8),), 'data_loading'), (((And(), 4),), 'multi_control_pauli'), # https://github.com/python/mypy/issues/5313 (((Reflection((3, 3, 2), (0, 0, 1)), 100),), 'reflection'), # type: ignore[arg-type] @@ -76,7 +76,7 @@ def test_default_classification(bloq_count, classification): (CSwap(10), 'swaps'), (HammingWeightPhasing(10, 1.11), 'rotations'), (Add(QInt(8)), 'arithmetic'), - (QROM.build([4, 10, 11, 34]), 'data_loading'), + (QROM.build_from_data([4, 10, 11, 34]), 'data_loading'), (And(), 'multi_control_pauli'), # https://github.com/python/mypy/issues/5313 (Reflection((3, 3, 2), (0, 0, 1)), 'reflection'), # type: ignore[arg-type] diff --git a/qualtran/symbolics/__init__.py b/qualtran/symbolics/__init__.py index 4f29a3a3e..50268b77b 100644 --- a/qualtran/symbolics/__init__.py +++ b/qualtran/symbolics/__init__.py @@ -20,8 +20,10 @@ floor, log2, pi, + prod, sabs, sconj, + shape, slen, smax, smin, diff --git a/qualtran/symbolics/math_funcs.py b/qualtran/symbolics/math_funcs.py index 1991a0a78..cd496c1c1 100644 --- a/qualtran/symbolics/math_funcs.py +++ b/qualtran/symbolics/math_funcs.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import cast, Sized, Union +from typing import cast, overload, Sized, Tuple, Union import numpy as np import sympy @@ -75,6 +75,13 @@ def smin(*args): return min(*args) +def prod(*args: SymbolicInt) -> SymbolicInt: + ret: SymbolicInt = 1 + for arg in args: + ret = ret * arg + return ret + + def acos(x: SymbolicFloat) -> SymbolicFloat: if not isinstance(x, sympy.Basic): return np.arccos(x) @@ -90,3 +97,17 @@ def slen(x: Union[Sized, Shaped]) -> SymbolicInt: if isinstance(x, Shaped): return x.shape[0] return len(x) + + +@overload +def shape(x: np.ndarray) -> Tuple[int, ...]: + ... + + +@overload +def shape(x: Shaped) -> Tuple[SymbolicInt, ...]: + ... + + +def shape(x: Union[np.ndarray, Shaped]): + return x.shape