diff --git a/horqrux/adjoint.py b/horqrux/adjoint.py index f8f076f..3dc1c83 100644 --- a/horqrux/adjoint.py +++ b/horqrux/adjoint.py @@ -66,7 +66,7 @@ def adjoint_expectation_single_observable_bwd( out_state = apply_gate(out_state, gate, values, OperationType.DAGGER) if isinstance(gate, Parametric): mu = apply_gate(out_state, gate, values, OperationType.JACOBIAN) - grads[gate.param] = tangent * 2 * inner(mu, projected_state).real + grads[gate.param_name] = tangent * 2 * inner(mu, projected_state).real projected_state = apply_gate(projected_state, gate, values, OperationType.DAGGER) return (None, None, None, grads) diff --git a/horqrux/api.py b/horqrux/api.py index 2eeb464..02cf521 100644 --- a/horqrux/api.py +++ b/horqrux/api.py @@ -6,7 +6,6 @@ import jax import jax.numpy as jnp from jax import Array -from jax.experimental import checkify from horqrux.adjoint import adjoint_expectation from horqrux.apply import apply_gate @@ -96,13 +95,6 @@ def expectation( elif diff_mode == DiffMode.ADJOINT: return adjoint_expectation(state, gates, observables, values) elif diff_mode == DiffMode.GPSR: - checkify.check( - forward_mode == ForwardMode.SHOTS, "Finite shots and GPSR must be used together" - ) - checkify.check( - type(n_shots) is int, - "Number of shots must be an integer for finite shots.", - ) # Type checking is disabled because mypy doesn't parse checkify.check. # type: ignore return finite_shots_fwd(state, gates, observables, values, n_shots=n_shots, key=key) diff --git a/horqrux/circuit.py b/horqrux/circuit.py index ad47d2a..ebca5c5 100644 --- a/horqrux/circuit.py +++ b/horqrux/circuit.py @@ -34,7 +34,7 @@ def __call__(self, param_values: Array) -> Array: @property def param_names(self) -> list[str]: - return [str(op.param) for op in self.ansatz if isinstance(op, Parametric)] + return [str(op.param_name) for op in self.ansatz if isinstance(op, Parametric)] @property def n_vparams(self) -> int: @@ -61,7 +61,7 @@ def hea(n_qubits: int, n_layers: int, rot_fns: list[Callable] = [RX, RY, RX]) -> fn(str(uuid4()), qubit) for fn, qubit in zip(rot_fns, [i for _ in range(len(rot_fns))]) ] - param_names += [op.param for op in ops] + param_names += [op.param_name for op in ops] ops += [NOT((i + 1) % n_qubits, i % n_qubits) for i in range(n_qubits)] # type: ignore[arg-type] gates += ops diff --git a/horqrux/parametric.py b/horqrux/parametric.py index bd5d488..591bd82 100644 --- a/horqrux/parametric.py +++ b/horqrux/parametric.py @@ -27,41 +27,49 @@ class Parametric(Primitive): generator_name: str target: QubitSupport control: QubitSupport - param: str | float = "" + param_name: str | None = None + param_val: float = 0.0 + shift: float = 0.0 def __post_init__(self) -> None: super().__post_init__() - def parse_dict(values: dict[str, float] = dict()) -> float: - return values[self.param] # type: ignore[index] + def parse_dict(self: Parametric, values: dict[str, float] = dict()) -> float: + return values[self.param_name] + self.shift # type: ignore[index] - def parse_val(values: dict[str, float] = dict()) -> float: - return self.param # type: ignore[return-value] + def parse_val(self: Parametric, values: dict[str, float] = dict()) -> float: + return self.param_val + self.shift # type: ignore[return-value] - self.parse_values = parse_dict if isinstance(self.param, str) else parse_val + self.parse_values = parse_val if self.param_name is None else parse_dict - def tree_flatten(self) -> Tuple[Tuple, Tuple[str, Tuple, Tuple, str | float]]: # type: ignore[override] - children = () - aux_data = ( - self.generator_name, - self.target[0], - self.control[0], - self.param, - ) + def tree_flatten( # type: ignore[override] + self, + ) -> tuple[tuple[float, float], tuple[str, tuple, tuple, str | None]]: + children = (self.param_val, self.shift) + aux_data = (self.generator_name, self.target[0], self.control[0], self.param_name) return (children, aux_data) def __iter__(self) -> Iterable: - return iter((self.generator_name, self.target, self.control, self.param)) + return iter( + ( + self.generator_name, + self.target, + self.control, + self.param_name, + self.param_val, + self.shift, + ) + ) @classmethod def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: - return cls(*children, *aux_data) + return cls(*aux_data, *children) def unitary(self, values: dict[str, float] = dict()) -> Array: - return _unitary(OPERATIONS_DICT[self.generator_name], self.parse_values(values)) + return _unitary(OPERATIONS_DICT[self.generator_name], self.parse_values(self, values)) def jacobian(self, values: dict[str, float] = dict()) -> Array: - return _jacobian(OPERATIONS_DICT[self.generator_name], self.parse_values(values)) + return _jacobian(OPERATIONS_DICT[self.generator_name], self.parse_values(self, values)) @property def name(self) -> str: @@ -70,9 +78,17 @@ def name(self) -> str: def __repr__(self) -> str: return ( - self.name + f"(target={self.target[0]}, control={self.control[0]}, param={self.param})" + self.name + + f"(target={self.target[0]}," + + f"control={self.control[0]}," + + f"param_name={self.param_name}," + + f"param_val={self.param_val}," + + f"shift={self.shift})" ) + def set_shift(self, shift: float) -> None: + self.shift = shift + def RX(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric: """RX gate. @@ -85,7 +101,10 @@ def RX(param: float | str, target: TargetQubits, control: ControlQubits = (None, Returns: Parametric: A Parametric gate object. """ - return Parametric("X", target, control, param) + + if isinstance(param, str): + return Parametric("X", target, control, param_name=param) + return Parametric("X", target, control, param_val=param) def RY(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric: @@ -99,7 +118,9 @@ def RY(param: float | str, target: TargetQubits, control: ControlQubits = (None, Returns: Parametric: A Parametric gate object. """ - return Parametric("Y", target, control, param) + if isinstance(param, str): + return Parametric("Y", target, control, param_name=param) + return Parametric("Y", target, control, param_val=param) def RZ(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric: @@ -113,18 +134,20 @@ def RZ(param: float | str, target: TargetQubits, control: ControlQubits = (None, Returns: Parametric: A Parametric gate object. """ - return Parametric("Z", target, control, param) + if isinstance(param, str): + return Parametric("Z", target, control, param_name=param) + return Parametric("Z", target, control, param_val=param) class _PHASE(Parametric): def unitary(self, values: dict[str, float] = dict()) -> Array: u = jnp.eye(2, 2, dtype=jnp.complex128) - u = u.at[(1, 1)].set(jnp.exp(1.0j * self.parse_values(values))) + u = u.at[(1, 1)].set(jnp.exp(1.0j * self.parse_values(self, values))) return u def jacobian(self, values: dict[str, float] = dict()) -> Array: jac = jnp.zeros((2, 2), dtype=jnp.complex128) - jac = jac.at[(1, 1)].set(1j * jnp.exp(1.0j * self.parse_values(values))) + jac = jac.at[(1, 1)].set(1j * jnp.exp(1.0j * self.parse_values(self, values))) return jac @property @@ -133,7 +156,7 @@ def name(self) -> str: return "C" + base_name if is_controlled(self.control) else base_name -def PHASE(param: float, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric: +def PHASE(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric: """Phase gate. Arguments: @@ -144,5 +167,6 @@ def PHASE(param: float, target: TargetQubits, control: ControlQubits = (None,)) Returns: Parametric: A Parametric gate object. """ - - return _PHASE("I", target, control, param) + if isinstance(param, str): + return _PHASE("I", target, control, param_name=param) + return _PHASE("I", target, control, param_val=param) diff --git a/horqrux/shots.py b/horqrux/shots.py index 4383100..ef22300 100644 --- a/horqrux/shots.py +++ b/horqrux/shots.py @@ -6,11 +6,10 @@ import jax import jax.numpy as jnp from jax import Array, random -from jax.experimental import checkify from horqrux.apply import apply_gate +from horqrux.parametric import Parametric from horqrux.primitive import GateSequence, Primitive -from horqrux.utils import none_like def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array: @@ -21,10 +20,6 @@ def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array: LIMITATION: currently only works for observables which are not controlled. """ - checkify.check( - observable.control == observable.parse_idx(none_like(observable.target)), - "Controlled gates cannot be promoted from observables to operations on the whole state vector", - ) unitary = observable.unitary() target = observable.target[0][0] identity = jnp.eye(2, dtype=unitary.dtype) @@ -33,7 +28,7 @@ def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array: return reduce(lambda x, y: jnp.kron(x, y), ops[1:], ops[0]) -@partial(jax.custom_jvp, nondiff_argnums=(0, 1, 2, 4, 5)) +@partial(jax.custom_jvp, nondiff_argnums=(0, 2, 4, 5)) def finite_shots_fwd( state: Array, gates: GateSequence, @@ -79,10 +74,6 @@ def align_eigenvectors(eigs: list[tuple[Array, Array]]) -> tuple[Array, Array]: for mat in eigs_copy: inv = jnp.linalg.inv(mat[1]) P = (inv @ eigenvector_matrix).real > 0.5 - checkify.check( - validate_permutation_matrix(P), - "Did not calculate valid permutation matrix", - ) eigenvalues.append(mat[0] @ P) return eigenvector_matrix, jnp.stack(eigenvalues, axis=1) @@ -97,33 +88,48 @@ def validate_permutation_matrix(P: Array) -> Array: @finite_shots_fwd.defjvp def finite_shots_jvp( state: Array, - gates: GateSequence, observable: Primitive, n_shots: int, key: Array, - primals: tuple[dict[str, float]], - tangents: tuple[dict[str, float]], + primals: tuple[list[Primitive], dict[str, float]], + tangents: tuple[list[Primitive], dict[str, float]], ) -> Array: - values = primals[0] - tangent_dict = tangents[0] + gates, values = primals + gates_tangent, values_tangent = tangents # TODO: compute spectral gap through the generator which is associated with # a param name. spectral_gap = 2.0 shift = jnp.pi / 2 - def jvp_component(param_name: str, key: Array) -> Array: + fwd = finite_shots_fwd(state, gates, observable, values, n_shots, key) + + keys = random.split(key, len(gates)) + jvp = jnp.zeros_like(fwd) + zero = jnp.zeros_like(fwd) + + def jvp_component(index: int) -> Array: + gates, values = primals + gates_tangent, values_tangent = tangents + shift_gate = gates[index] + gate_tangent = gates_tangent[index] + if not isinstance(shift_gate, Parametric) or not isinstance(gate_tangent, Parametric): + return zero + if shift_gate.param_name is None: + tangent = gate_tangent.param_val + else: + tangent = values_tangent[shift_gate.param_name] + if not isinstance(tangent, jax.Array): + return zero + key = keys[index] up_key, down_key = random.split(key) - up_val = values.copy() - up_val[param_name] = up_val[param_name] + shift - f_up = finite_shots_fwd(state, gates, observable, up_val, n_shots, up_key) - down_val = values.copy() - down_val[param_name] = down_val[param_name] - shift - f_down = finite_shots_fwd(state, gates, observable, down_val, n_shots, down_key) + original_shift = shift_gate.shift + shift_gate.set_shift(original_shift + shift) + f_up = finite_shots_fwd(state, gates, observable, values, n_shots, up_key) + shift_gate.set_shift(original_shift - shift) + f_down = finite_shots_fwd(state, gates, observable, values, n_shots, down_key) + shift_gate.set_shift(original_shift) grad = spectral_gap * (f_up - f_down) / (4.0 * jnp.sin(spectral_gap * shift / 2.0)) - return grad * tangent_dict[param_name] + return grad * tangent - params_with_keys = zip(values.keys(), random.split(key, len(values))) - fwd = finite_shots_fwd(state, gates, observable, values, n_shots, key) - jvp = sum(jvp_component(param, key) for param, key in params_with_keys) - return fwd, jvp.reshape(fwd.shape) + return fwd, sum(jvp_component(i) for i, _ in enumerate(gates)) diff --git a/tests/test_shots.py b/tests/test_shots.py index c98062d..7a589bf 100644 --- a/tests/test_shots.py +++ b/tests/test_shots.py @@ -1,5 +1,7 @@ from __future__ import annotations +import functools + import jax import jax.numpy as jnp @@ -13,28 +15,29 @@ def test_shots() -> None: - ops = [RX("theta", 0)] observables = [Z(0), Z(1)] state = random_state(N_QUBITS) - x = jnp.pi * 0.5 + x = jnp.pi * 0.123 + y = jnp.pi * 0.456 - def exact(x): + @functools.partial(jax.jit, static_argnums=2) + def expect(x, y, method): values = {"theta": x} + ops = [RX("theta", 0), RX(0.2, 0), RX(y, 1), RX("theta", 1)] + if method == "shots": + return expectation(state, ops, observables, values, "gpsr", "shots", n_shots=N_SHOTS) return expectation(state, ops, observables, values, "ad") - def shots(x): - values = {"theta": x} - return expectation(state, ops, observables, values, "gpsr", "shots", n_shots=N_SHOTS) - - exp_exact = exact(x) - exp_shots = shots(x) + exp_exact = expect(x, y, "exact") + exp_shots = expect(x, y, "shots") assert jnp.allclose(exp_exact, exp_shots, atol=SHOTS_ATOL) - d_exact = jax.grad(lambda x: exact(x).sum()) - d_shots = jax.grad(lambda x: shots(x).sum()) + d_expect = jax.jit( + jax.grad(lambda x, y, z: expect(x, y, z).sum(), argnums=[0, 1]), static_argnums=2 + ) - grad_backprop = d_exact(x) - grad_shots = d_shots(x) + grad_backprop = jnp.stack(d_expect(x, y, "exact")) + grad_shots = jnp.stack(d_expect(x, y, "shots")) - assert jnp.isclose(grad_backprop, grad_shots, atol=SHOTS_ATOL) + assert jnp.allclose(grad_backprop, grad_shots, atol=SHOTS_ATOL)