Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add custom vjp with adjoint differentiation #10

Merged
merged 9 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ new_state = apply_gate(state, RX(param_value, target_qubit, control_qubit))

We can now build a fully differentiable variational circuit by simply defining a sequence of gates
and a set of initial parameter values we want to optimize.
Lets fit a function using a simple circuit class wrapper.
Horqrux provides an implementation of [adjoint differentiation](https://arxiv.org/abs/2009.02823),
which we can use to fit a function using a simple circuit class wrapper.

```python exec="on" source="material-block" html="1"
from __future__ import annotations
Expand All @@ -85,6 +86,7 @@ from operator import add
from typing import Any, Callable
from uuid import uuid4

from horqrux.adjoint import adjoint_expectation
from horqrux.abstract import Operator
from horqrux import Z, RX, RY, NOT, zero_state, apply_gate, overlap

Expand Down Expand Up @@ -127,8 +129,7 @@ class Circuit:
def forward(self, param_values: Array, x: Array) -> Array:
state = zero_state(self.n_qubits)
param_dict = {name: val for name, val in zip(self.param_names, param_values)}
state = apply_gate(state, self.feature_map + self.ansatz, {**param_dict, **{'phi': x}})
return overlap(state, apply_gate(state, self.observable))
return adjoint_expectation(state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}})

def __call__(self, param_values: Array, x: Array) -> Array:
return self.forward(param_values, x)
Expand Down
15 changes: 11 additions & 4 deletions horqrux/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __iter__(self) -> Iterable:

def tree_flatten(self) -> Tuple[Tuple, Tuple[str, TargetQubits, ControlQubits]]:
children = ()
aux_data = (self.generator_name, self.target, self.control)
aux_data = (self.generator_name, self.target[0], self.control[0])
return (children, aux_data)

@classmethod
Expand Down Expand Up @@ -101,13 +101,20 @@ def parse_val(values: dict[str, float] = dict()) -> float:
def tree_flatten(self) -> Tuple[Tuple, Tuple[str, Tuple, Tuple, str | float]]: # type: ignore[override]
children = ()
aux_data = (
self.name,
self.target,
self.control,
self.generator_name,
self.target[0],
self.control[0],
self.param,
)
return (children, aux_data)

def __iter__(self) -> Iterable:
return iter((self.generator_name, self.target, self.control, self.param))

@classmethod
def tree_unflatten(cls, aux_data: Any, children: Any) -> Any:
return cls(*children, *aux_data)

def unitary(self, values: dict[str, float] = dict()) -> Array:
return _unitary(OPERATIONS_DICT[self.generator_name], self.parse_values(values))

Expand Down
49 changes: 49 additions & 0 deletions horqrux/adjoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

from typing import Tuple

from jax import Array, custom_vjp

from horqrux.abstract import Operator, Parametric
from horqrux.apply import apply_gate
from horqrux.utils import OperationType, overlap


def expectation(
state: Array, gates: list[Operator], observable: list[Operator], values: dict[str, float]
) -> Array:
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
return overlap(out_state, projected_state)


@custom_vjp
def adjoint_expectation(
state: Array, gates: list[Operator], observable: list[Operator], values: dict[str, float]
) -> Array:
return expectation(state, gates, observable, values)


def adjoint_expectation_fwd(
state: Array, gates: list[Operator], observable: list[Operator], values: dict[str, float]
) -> Tuple[Array, Tuple[Array, Array, list[Operator], dict[str, float]]]:
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
return overlap(out_state, projected_state), (out_state, projected_state, gates, values)


def adjoint_expectation_bwd(
res: Tuple[Array, Array, list[Operator], dict[str, float]], tangent: Array
) -> tuple:
out_state, projected_state, gates, values = res
grads = {}
for gate in gates[::-1]:
out_state = apply_gate(out_state, gate, values, OperationType.DAGGER)
if isinstance(gate, Parametric):
mu = apply_gate(out_state, gate, values, OperationType.JACOBIAN)
grads[gate.param] = tangent * 2 * overlap(mu, projected_state)
projected_state = apply_gate(projected_state, gate, values, OperationType.DAGGER)
return (None, None, None, grads)


adjoint_expectation.defvjp(adjoint_expectation_fwd, adjoint_expectation_bwd)
44 changes: 25 additions & 19 deletions horqrux/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,67 +10,73 @@

from horqrux.abstract import Operator

from .utils import State, _controlled, is_controlled
from .utils import OperationType, State, _controlled, is_controlled


def apply_operator(
state: State,
unitary: Array,
operator: Array,
target: Tuple[int, ...],
control: Tuple[int | None, ...],
) -> State:
"""Applies a unitary, i.e. a single array of shape [2, 2, ...], on a given state
"""Applies an operator, i.e. a single array of shape [2, 2, ...], on a given state
of shape [2 for _ in range(n_qubits)] for a given set of target and control qubits.
In case of control qubits, the 'unitary' array will be embedded into a controlled array.
In case of a controlled operation, the 'operator' array will be embedded into a controlled array.

Since dimension 'i' in 'state' corresponds to all amplitudes which are affected by qubit 'i',
target and control qubits correspond to dimensions to contract 'unitary' over.
Contraction over qubit 'i' means applying the 'dot' operation between 'unitary' and dimension 'i'
Since dimension 'i' in 'state' corresponds to all amplitudes where qubit 'i' is 1,
target and control qubits represent the dimensions over which to contract the 'operator'.
Contraction means applying the 'dot' operation between the operator array and dimension 'i'
of 'state, resulting in a new state where the result of the 'dot' operation has been moved to
dimension 'i' of 'state'. To restore the former order of dimensions, the affected dimensions
are moved to their original positions and the state is returned.

Arguments:
state: State to operate on.
unitary: Array to contract over 'state'.
operator: Array to contract over 'state'.
target: Tuple of target qubits on which to apply the 'operator' to.
control: Tuple of control qubits.

Returns:
State after applying 'unitary'.
State after applying 'operator'.
"""
state_dims: Tuple[int, ...] = target
if is_controlled(control):
unitary = _controlled(unitary, len(control))
operator = _controlled(operator, len(control))
state_dims = (*control, *target) # type: ignore[arg-type]
n_qubits = int(np.log2(unitary.size))
unitary = unitary.reshape(tuple(2 for _ in np.arange(n_qubits)))
op_dims = tuple(np.arange(unitary.ndim // 2, unitary.ndim, dtype=int))
state = jnp.tensordot(a=unitary, b=state, axes=(op_dims, state_dims))
n_qubits = int(np.log2(operator.size))
operator = operator.reshape(tuple(2 for _ in np.arange(n_qubits)))
op_dims = tuple(np.arange(operator.ndim // 2, operator.ndim, dtype=int))
state = jnp.tensordot(a=operator, b=state, axes=(op_dims, state_dims))
new_state_dims = tuple(i for i in range(len(state_dims)))
return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims)


def apply_gate(
state: State, gate: Operator | Iterable[Operator], values: dict[str, float] = dict()
state: State,
gate: Operator | Iterable[Operator],
values: dict[str, float] = dict(),
op_type: OperationType = OperationType.UNITARY,
) -> State:
"""Wrapper function for 'apply_operator' which applies a gate or a series of gates to a given state.
Arguments:
state: State to operate on.
gate: Gate(s) to apply.
values: A dictionary with parameter values.
op_type: The type of operation to perform: Unitary, Dagger or Jacobian.

Returns:
State after applying 'gate'.
"""
unitary: Tuple[Array, ...]
operator: Tuple[Array, ...]
if isinstance(gate, Operator):
unitary, target, control = (gate.unitary(values),), gate.target, gate.control
operator_fn = getattr(gate, op_type)
operator, target, control = (operator_fn(values),), gate.target, gate.control
else:
unitary = tuple(g.unitary(values) for g in gate)
operator = tuple(getattr(g, op_type)(values) for g in gate)
target = reduce(add, [g.target for g in gate])
control = reduce(add, [g.control for g in gate])
return reduce(
lambda state, gate: apply_operator(state, *gate),
zip(unitary, target, control),
zip(operator, target, control),
state,
)
20 changes: 10 additions & 10 deletions horqrux/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def I(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
"""Identity / I gate. This function returns an instance of 'Primitive' and does *not* apply the gate.
By providing tuple of ints to 'control', it turns into a controlled gate.

Example usage: I(1) applies I to qubit 1.
Example usage: I(1) represents the instruction to apply I to qubit 1.

Args:
target: Tuple of ints describing the qubits to apply to.
Expand All @@ -26,8 +26,8 @@ def X(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
"""X gate. This function returns an instance of 'Primitive' and does *not* apply the gate.
By providing tuple of ints to 'control', it turns into a controlled gate.

Example usage: X(1) applies X to qubit 1.
Example usage controlled: X(1, 0) applies CX / CNOT to qubit 1 with controlled qubit 0.
Example usage: X(1) represents the instruction to apply X to qubit 1.
Example usage controlled: X(1, 0) represents the instruction to apply CX / CNOT to qubit 1 with controlled qubit 0.

Args:
target: Tuple of ints describing the qubits to apply to.
Expand All @@ -46,8 +46,8 @@ def Y(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
"""Y gate. This function returns an instance of 'Primitive' and does *not* apply the gate.
By providing tuple of ints to 'control', it turns into a controlled gate.

Example usage: Y(1) applies X to qubit 1.
Example usage controlled: Y(1, 0) applies CY to qubit 1 with controlled qubit 0.
Example usage: Y(1) represents the instruction to apply X to qubit 1.
Example usage controlled: Y(1, 0) represents the instruction to apply CY to qubit 1 with controlled qubit 0.

Args:
target: Tuple of ints describing the qubits to apply to.
Expand All @@ -63,8 +63,8 @@ def Z(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
"""Z gate. This function returns an instance of 'Primitive' and does *not* apply the gate.
By providing tuple of ints to 'control', it turns into a controlled gate.

Example usage: Z(1) applies Z to qubit 1.
Example usage controlled: Z(1, 0) applies CZ to qubit 1 with controlled qubit 0.
Example usage: Z(1) represents the instruction to apply Z to qubit 1.
Example usage controlled: Z(1, 0) represents the instruction to apply CZ to qubit 1 with controlled qubit 0.

Args:
target: Tuple of ints describing the qubits to apply to.
Expand All @@ -80,7 +80,7 @@ def H(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
"""H/ Hadamard gate. This function returns an instance of 'Primitive' and does *not* apply the gate.
By providing tuple of ints to 'control', it turns into a controlled gate.

Example usage: H(1) applies Hadamard to qubit 1.
Example usage: H(1) represents the instruction to apply Hadamard to qubit 1.

Args:
target: Tuple of ints describing the qubits to apply to.
Expand All @@ -96,7 +96,7 @@ def S(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
"""S gate or constant phase gate. This function returns an instance of 'Primitive' and does *not* apply the gate.
By providing tuple of ints to 'control', it turns into a controlled gate.

Example usage: S(1) applies S to qubit 1.
Example usage: S(1) represents the instruction to apply S to qubit 1.

Args:
target: Tuple of ints describing the qubits to apply to.
Expand All @@ -112,7 +112,7 @@ def T(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
"""T gate. This function returns an instance of 'Primitive' and does *not* apply the gate.
By providing tuple of ints to 'control', it turns into a controlled gate.

Example usage: T(1) applies Hadamard to qubit 1.
Example usage: T(1) represents the instruction to apply Hadamard to qubit 1.

Args:
target: Tuple of ints describing the qubits to apply to.
Expand Down
18 changes: 18 additions & 0 deletions horqrux/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from enum import Enum
from typing import Any, Iterable, Tuple, Union

import jax
Expand All @@ -16,6 +17,23 @@
ATOL = 1e-014


class StrEnum(str, Enum):
def __str__(self) -> str:
"""Used when dumping enum fields in a schema."""
ret: str = self.value
return ret

@classmethod
def list(cls) -> list[str]:
return list(map(lambda c: c.value, cls)) # type: ignore


class OperationType(StrEnum):
UNITARY = "unitary"
DAGGER = "dagger"
JACOBIAN = "jacobian"


def _dagger(operator: Array) -> Array:
return jnp.conjugate(operator.T)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ authors = [
requires-python = ">=3.9,<3.12"
license = {text = "Apache 2.0"}

version = "0.4.0"
version = "0.5.0"

classifiers=[
"License :: Other/Proprietary License",
Expand Down
37 changes: 37 additions & 0 deletions tests/test_adjoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations

import jax.numpy as jnp
import numpy as np
from jax import Array, grad

from horqrux import random_state
from horqrux.adjoint import adjoint_expectation, expectation
from horqrux.parametric import PHASE, RX, RY, RZ
from horqrux.primitive import NOT, H, I, S, T, X, Y, Z

MAX_QUBITS = 7
PARAMETRIC_GATES = (RX, RY, RZ, PHASE)
PRIMITIVE_GATES = (NOT, H, X, Y, Z, I, S, T)


def test_gradcheck() -> None:
ops = [RX("theta", 0), RY("epsilon", 0), RX("phi", 0), NOT(1, 0), RX("omega", 0, 1)]
observable = [Z(0)]
values = {
"theta": np.random.uniform(0, 1),
"epsilon": np.random.uniform(0, 1),
"phi": np.random.uniform(0, 1),
"omega": np.random.uniform(0, 1),
}
state = random_state(MAX_QUBITS)

def adjoint_expfn(values) -> Array:
return adjoint_expectation(state, ops, observable, values)

def ad_expfn(values) -> Array:
return expectation(state, ops, observable, values)

grads_adjoint = grad(adjoint_expfn)(values)
grad_ad = grad(ad_expfn)(values)
for param, ad_grad in grad_ad.items():
assert jnp.isclose(grads_adjoint[param], ad_grad, atol=0.09)
Loading