Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Allow parameter shifting on all parametric gates #31

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion horqrux/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def adjoint_expectation_single_observable_bwd(
out_state = apply_gate(out_state, gate, values, OperationType.DAGGER)
if isinstance(gate, Parametric):
mu = apply_gate(out_state, gate, values, OperationType.JACOBIAN)
grads[gate.param] = tangent * 2 * inner(mu, projected_state).real
grads[gate.param_name] = tangent * 2 * inner(mu, projected_state).real
projected_state = apply_gate(projected_state, gate, values, OperationType.DAGGER)
return (None, None, None, grads)

Expand Down
8 changes: 0 additions & 8 deletions horqrux/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import jax
import jax.numpy as jnp
from jax import Array
from jax.experimental import checkify

from horqrux.adjoint import adjoint_expectation
from horqrux.apply import apply_gate
Expand Down Expand Up @@ -96,13 +95,6 @@ def expectation(
elif diff_mode == DiffMode.ADJOINT:
return adjoint_expectation(state, gates, observables, values)
elif diff_mode == DiffMode.GPSR:
checkify.check(
forward_mode == ForwardMode.SHOTS, "Finite shots and GPSR must be used together"
)
checkify.check(
type(n_shots) is int,
"Number of shots must be an integer for finite shots.",
)
# Type checking is disabled because mypy doesn't parse checkify.check.
# type: ignore
return finite_shots_fwd(state, gates, observables, values, n_shots=n_shots, key=key)
4 changes: 2 additions & 2 deletions horqrux/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __call__(self, param_values: Array) -> Array:

@property
def param_names(self) -> list[str]:
return [str(op.param) for op in self.ansatz if isinstance(op, Parametric)]
return [str(op.param_name) for op in self.ansatz if isinstance(op, Parametric)]

@property
def n_vparams(self) -> int:
Expand All @@ -61,7 +61,7 @@ def hea(n_qubits: int, n_layers: int, rot_fns: list[Callable] = [RX, RY, RX]) ->
fn(str(uuid4()), qubit)
for fn, qubit in zip(rot_fns, [i for _ in range(len(rot_fns))])
]
param_names += [op.param for op in ops]
param_names += [op.param_name for op in ops]
ops += [NOT((i + 1) % n_qubits, i % n_qubits) for i in range(n_qubits)] # type: ignore[arg-type]
gates += ops

Expand Down
78 changes: 51 additions & 27 deletions horqrux/parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,41 +27,49 @@ class Parametric(Primitive):
generator_name: str
target: QubitSupport
control: QubitSupport
param: str | float = ""
param_name: str | None = None
param_val: float = 0.0
shift: float = 0.0

def __post_init__(self) -> None:
super().__post_init__()

def parse_dict(values: dict[str, float] = dict()) -> float:
return values[self.param] # type: ignore[index]
def parse_dict(self: Parametric, values: dict[str, float] = dict()) -> float:
return values[self.param_name] + self.shift # type: ignore[index]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to remove these type ignores ?


def parse_val(values: dict[str, float] = dict()) -> float:
return self.param # type: ignore[return-value]
def parse_val(self: Parametric, values: dict[str, float] = dict()) -> float:
return self.param_val + self.shift # type: ignore[return-value]

self.parse_values = parse_dict if isinstance(self.param, str) else parse_val
self.parse_values = parse_val if self.param_name is None else parse_dict

def tree_flatten(self) -> Tuple[Tuple, Tuple[str, Tuple, Tuple, str | float]]: # type: ignore[override]
children = ()
aux_data = (
self.generator_name,
self.target[0],
self.control[0],
self.param,
)
def tree_flatten( # type: ignore[override]
self,
) -> tuple[tuple[float, float], tuple[str, tuple, tuple, str | None]]:
children = (self.param_val, self.shift)
aux_data = (self.generator_name, self.target[0], self.control[0], self.param_name)
return (children, aux_data)

def __iter__(self) -> Iterable:
return iter((self.generator_name, self.target, self.control, self.param))
return iter(
(
self.generator_name,
self.target,
self.control,
self.param_name,
self.param_val,
self.shift,
)
)

@classmethod
def tree_unflatten(cls, aux_data: Any, children: Any) -> Any:
return cls(*children, *aux_data)
return cls(*aux_data, *children)

def unitary(self, values: dict[str, float] = dict()) -> Array:
return _unitary(OPERATIONS_DICT[self.generator_name], self.parse_values(values))
return _unitary(OPERATIONS_DICT[self.generator_name], self.parse_values(self, values))

def jacobian(self, values: dict[str, float] = dict()) -> Array:
return _jacobian(OPERATIONS_DICT[self.generator_name], self.parse_values(values))
return _jacobian(OPERATIONS_DICT[self.generator_name], self.parse_values(self, values))

@property
def name(self) -> str:
Expand All @@ -70,9 +78,17 @@ def name(self) -> str:

def __repr__(self) -> str:
return (
self.name + f"(target={self.target[0]}, control={self.control[0]}, param={self.param})"
self.name
+ f"(target={self.target[0]},"
+ f"control={self.control[0]},"
+ f"param_name={self.param_name},"
+ f"param_val={self.param_val},"
+ f"shift={self.shift})"
)

def set_shift(self, shift: float) -> None:
self.shift = shift


def RX(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric:
"""RX gate.
Expand All @@ -85,7 +101,10 @@ def RX(param: float | str, target: TargetQubits, control: ControlQubits = (None,
Returns:
Parametric: A Parametric gate object.
"""
return Parametric("X", target, control, param)

if isinstance(param, str):
return Parametric("X", target, control, param_name=param)
return Parametric("X", target, control, param_val=param)


def RY(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric:
Expand All @@ -99,7 +118,9 @@ def RY(param: float | str, target: TargetQubits, control: ControlQubits = (None,
Returns:
Parametric: A Parametric gate object.
"""
return Parametric("Y", target, control, param)
if isinstance(param, str):
return Parametric("Y", target, control, param_name=param)
return Parametric("Y", target, control, param_val=param)


def RZ(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric:
Expand All @@ -113,18 +134,20 @@ def RZ(param: float | str, target: TargetQubits, control: ControlQubits = (None,
Returns:
Parametric: A Parametric gate object.
"""
return Parametric("Z", target, control, param)
if isinstance(param, str):
return Parametric("Z", target, control, param_name=param)
return Parametric("Z", target, control, param_val=param)


class _PHASE(Parametric):
def unitary(self, values: dict[str, float] = dict()) -> Array:
u = jnp.eye(2, 2, dtype=jnp.complex128)
u = u.at[(1, 1)].set(jnp.exp(1.0j * self.parse_values(values)))
u = u.at[(1, 1)].set(jnp.exp(1.0j * self.parse_values(self, values)))
return u

def jacobian(self, values: dict[str, float] = dict()) -> Array:
jac = jnp.zeros((2, 2), dtype=jnp.complex128)
jac = jac.at[(1, 1)].set(1j * jnp.exp(1.0j * self.parse_values(values)))
jac = jac.at[(1, 1)].set(1j * jnp.exp(1.0j * self.parse_values(self, values)))
return jac

@property
Expand All @@ -133,7 +156,7 @@ def name(self) -> str:
return "C" + base_name if is_controlled(self.control) else base_name


def PHASE(param: float, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric:
def PHASE(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric:
"""Phase gate.

Arguments:
Expand All @@ -144,5 +167,6 @@ def PHASE(param: float, target: TargetQubits, control: ControlQubits = (None,))
Returns:
Parametric: A Parametric gate object.
"""

return _PHASE("I", target, control, param)
if isinstance(param, str):
return _PHASE("I", target, control, param_name=param)
return _PHASE("I", target, control, param_val=param)
62 changes: 34 additions & 28 deletions horqrux/shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
import jax
import jax.numpy as jnp
from jax import Array, random
from jax.experimental import checkify

from horqrux.apply import apply_gate
from horqrux.parametric import Parametric
from horqrux.primitive import GateSequence, Primitive
from horqrux.utils import none_like


def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array:
Expand All @@ -21,10 +20,6 @@ def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array:

LIMITATION: currently only works for observables which are not controlled.
"""
checkify.check(
observable.control == observable.parse_idx(none_like(observable.target)),
"Controlled gates cannot be promoted from observables to operations on the whole state vector",
)
unitary = observable.unitary()
target = observable.target[0][0]
identity = jnp.eye(2, dtype=unitary.dtype)
Expand All @@ -33,7 +28,7 @@ def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array:
return reduce(lambda x, y: jnp.kron(x, y), ops[1:], ops[0])


@partial(jax.custom_jvp, nondiff_argnums=(0, 1, 2, 4, 5))
@partial(jax.custom_jvp, nondiff_argnums=(0, 2, 4, 5))
def finite_shots_fwd(
state: Array,
gates: GateSequence,
Expand Down Expand Up @@ -79,10 +74,6 @@ def align_eigenvectors(eigs: list[tuple[Array, Array]]) -> tuple[Array, Array]:
for mat in eigs_copy:
inv = jnp.linalg.inv(mat[1])
P = (inv @ eigenvector_matrix).real > 0.5
checkify.check(
validate_permutation_matrix(P),
"Did not calculate valid permutation matrix",
)
eigenvalues.append(mat[0] @ P)
return eigenvector_matrix, jnp.stack(eigenvalues, axis=1)

Expand All @@ -97,33 +88,48 @@ def validate_permutation_matrix(P: Array) -> Array:
@finite_shots_fwd.defjvp
def finite_shots_jvp(
state: Array,
gates: GateSequence,
observable: Primitive,
n_shots: int,
key: Array,
primals: tuple[dict[str, float]],
tangents: tuple[dict[str, float]],
primals: tuple[list[Primitive], dict[str, float]],
tangents: tuple[list[Primitive], dict[str, float]],
) -> Array:
values = primals[0]
tangent_dict = tangents[0]
gates, values = primals
gates_tangent, values_tangent = tangents

# TODO: compute spectral gap through the generator which is associated with
# a param name.
spectral_gap = 2.0
shift = jnp.pi / 2

def jvp_component(param_name: str, key: Array) -> Array:
fwd = finite_shots_fwd(state, gates, observable, values, n_shots, key)

keys = random.split(key, len(gates))
jvp = jnp.zeros_like(fwd)
zero = jnp.zeros_like(fwd)

def jvp_component(index: int) -> Array:
gates, values = primals
gates_tangent, values_tangent = tangents
shift_gate = gates[index]
gate_tangent = gates_tangent[index]
if not isinstance(shift_gate, Parametric) or not isinstance(gate_tangent, Parametric):
return zero
if shift_gate.param_name is None:
tangent = gate_tangent.param_val
else:
tangent = values_tangent[shift_gate.param_name]
if not isinstance(tangent, jax.Array):
return zero
key = keys[index]
up_key, down_key = random.split(key)
up_val = values.copy()
up_val[param_name] = up_val[param_name] + shift
f_up = finite_shots_fwd(state, gates, observable, up_val, n_shots, up_key)
down_val = values.copy()
down_val[param_name] = down_val[param_name] - shift
f_down = finite_shots_fwd(state, gates, observable, down_val, n_shots, down_key)
original_shift = shift_gate.shift
shift_gate.set_shift(original_shift + shift)
f_up = finite_shots_fwd(state, gates, observable, values, n_shots, up_key)
shift_gate.set_shift(original_shift - shift)
f_down = finite_shots_fwd(state, gates, observable, values, n_shots, down_key)
shift_gate.set_shift(original_shift)
grad = spectral_gap * (f_up - f_down) / (4.0 * jnp.sin(spectral_gap * shift / 2.0))
return grad * tangent_dict[param_name]
return grad * tangent

params_with_keys = zip(values.keys(), random.split(key, len(values)))
fwd = finite_shots_fwd(state, gates, observable, values, n_shots, key)
jvp = sum(jvp_component(param, key) for param, key in params_with_keys)
return fwd, jvp.reshape(fwd.shape)
return fwd, sum(jvp_component(i) for i, _ in enumerate(gates))
31 changes: 17 additions & 14 deletions tests/test_shots.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import functools

import jax
import jax.numpy as jnp

Expand All @@ -13,28 +15,29 @@


def test_shots() -> None:
ops = [RX("theta", 0)]
observables = [Z(0), Z(1)]
state = random_state(N_QUBITS)
x = jnp.pi * 0.5
x = jnp.pi * 0.123
y = jnp.pi * 0.456

def exact(x):
@functools.partial(jax.jit, static_argnums=2)
def expect(x, y, method):
values = {"theta": x}
ops = [RX("theta", 0), RX(0.2, 0), RX(y, 1), RX("theta", 1)]
if method == "shots":
return expectation(state, ops, observables, values, "gpsr", "shots", n_shots=N_SHOTS)
return expectation(state, ops, observables, values, "ad")

def shots(x):
values = {"theta": x}
return expectation(state, ops, observables, values, "gpsr", "shots", n_shots=N_SHOTS)

exp_exact = exact(x)
exp_shots = shots(x)
exp_exact = expect(x, y, "exact")
exp_shots = expect(x, y, "shots")

assert jnp.allclose(exp_exact, exp_shots, atol=SHOTS_ATOL)

d_exact = jax.grad(lambda x: exact(x).sum())
d_shots = jax.grad(lambda x: shots(x).sum())
d_expect = jax.jit(
jax.grad(lambda x, y, z: expect(x, y, z).sum(), argnums=[0, 1]), static_argnums=2
)

grad_backprop = d_exact(x)
grad_shots = d_shots(x)
grad_backprop = jnp.stack(d_expect(x, y, "exact"))
grad_shots = jnp.stack(d_expect(x, y, "shots"))

assert jnp.isclose(grad_backprop, grad_shots, atol=SHOTS_ATOL)
assert jnp.allclose(grad_backprop, grad_shots, atol=SHOTS_ATOL)
Loading