New tutorial
dominikandreasseitz committed Jan 24, 2024
1 parent 059a717 commit 3bb4f11
Showing 3 changed files with 93 additions and 21 deletions.
11 changes: 11 additions & 0 deletions docs/docsutils.py
@@ -0,0 +1,11 @@
from __future__ import annotations

from io import StringIO

from matplotlib.figure import Figure


def fig_to_html(fig: Figure) -> str:
    """Serialize a matplotlib figure to an SVG string for embedding in rendered docs."""
    buffer = StringIO()
    fig.savefig(buffer, format="svg")
    return buffer.getvalue()
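A minimal usage sketch of this helper (the figure here is a hypothetical placeholder; the `docs.docsutils` import path matches how the tutorial below uses it):

```python
import matplotlib.pyplot as plt

from docs.docsutils import fig_to_html

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1])
# Emit the figure as an inline SVG string, which markdown-exec can embed in the page
print(fig_to_html(fig))
```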
102 changes: 81 additions & 21 deletions docs/index.md
@@ -75,34 +75,94 @@ new_state = apply_gate(state, RX(param_value, target_qubit, control_qubit))

A fully differentiable variational circuit is simply a sequence of gates applied to a state.

```python exec="on" source="material-block" html="1"
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from functools import partial, reduce
from jax import Array, jit, value_and_grad, vmap
from operator import add

from horqrux import NOT, RX, RY, Z
from horqrux.abstract import Operator
from horqrux.apply import apply_gate
from horqrux.utils import overlap, zero_state

n_qubits = 5
# Let's define a sequence of rotations: each qubit gets an RX-RY-RX block, repeated over the layers
n_params = 3
n_layers = 3
param_names = [f'theta_{i}' for i in range(n_params * n_layers * n_qubits)]
qubits = [q for _ in range(n_params) for q in range(n_qubits) for _ in range(n_layers)]

rots = [(RX, RY, RX) for _ in range(n_layers * n_qubits)]
ops = [fn(param, qubit) for fn, param, qubit in zip([item for tple in rots for item in tple], param_names, qubits)]
# Create random initial values for the parameters
key = jax.random.PRNGKey(42)
param_vals = jax.random.uniform(key, shape=(len(ops),))
param_dict = {name: val for name, val in zip(param_names, param_vals)}
# We will use a feature map of RX rotations to encode some classical data
x = jnp.linspace(0, 10, 100)
# We need a function which runs our circuit
def circ(param_dict: dict[str, float], rotations: list[Operator] = ops, n_qubits: int = n_qubits) -> jax.Array:
    feature_map = [RX('phi', i) for i in range(n_qubits)]
    entangling = [NOT((i + 1) % n_qubits, i) for i in range(n_qubits)]
    observable = [Z(i) for i in range(n_qubits)]
    state = apply_gate(zero_state(n_qubits), feature_map + rotations + entangling, param_dict)
    # Apply Z on every qubit and take the overlap with the original state
    projection = apply_gate(state, observable)
    return overlap(state, projection)

# Let's create a convenience lambda fn for forward passes: it injects the data value
# under the feature-map key 'phi' alongside the variational parameters.
expfn = lambda p, v: circ({**p, 'phi': v})

# Check the initial predictions using randomly initialized parameters
y_init = vmap(partial(expfn, param_dict), in_axes=(0,))(x)
# Let's compute both values and gradients for a set of parameters.
expval_and_grads = value_and_grad(lambda p: expfn(p, 1.))(param_dict)
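# value_and_grad returns the expectation value together with a gradient dict keyed like param_dict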

# We can also train our model to fit a function

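# The target is a truncated Fourier series: f(x) = 0.05 * sum_{i=0}^{degree-1} (cos(i*x) + sin(i*x))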
fn = lambda x, degree: .05 * reduce(add, (jnp.cos(i*x) + jnp.sin(i*x) for i in range(degree)), 0)
DEGREE = 5
y = fn(x, DEGREE)

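# optax optimizers are stateless: optimizer.init builds the optimizer state from the
# parameter pytree, and it is threaded explicitly through every update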
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(param_dict)


def optimize_step(params: dict[str, Array], opt_state: optax.OptState, grads: dict[str, Array]) -> tuple:
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state


def loss_fn(params: dict, v: float) -> Array:
    # Squared error between the circuit prediction and the target at a single point
    return (expfn(params, v) - fn(v, DEGREE)) ** 2

@jit
def train_step(i: int, inputs: tuple) -> tuple:
    params, opt_state = inputs
    def vng(v: float) -> tuple:
        return value_and_grad(partial(loss_fn, v=v))(params)
    # Batch the per-point loss and gradients over all data points, then average them
    loss, grads = vmap(vng, in_axes=(0,))(x)
    loss, grads = jnp.mean(loss), {param: jnp.mean(grad_vals) for param, grad_vals in grads.items()}
    params, opt_state = optimize_step(params, opt_state, grads)
    jax.debug.print("epoch {i} loss: {loss}", i=i, loss=loss)
    return params, opt_state

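# jax.lax.fori_loop keeps the whole training loop inside one compiled program;
# its body must map (i, carry) -> carry, which train_step does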
n_epochs = 1000
param_dict, opt_state = jax.lax.fori_loop(0, n_epochs, train_step, (param_dict, opt_state))
y_final = vmap(partial(expfn, param_dict), in_axes=(0,))(x)

# Let's plot the results
plt.plot(x, y, label="truth")
plt.plot(x, y_init, label="initial")
plt.plot(x, y_final, "--", label="final", linewidth=3)
plt.legend()
from docs.docsutils import fig_to_html # markdown-exec: hide
print(fig_to_html(plt.gcf())) # markdown-exec: hide
```
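After training, the same `expfn` helper can be reused for single-point predictions. A minimal sketch, assuming the tutorial block above has run (the input value 4.2 is arbitrary):

```python
x_star = 4.2
# Forward pass of the trained circuit at one new input value
y_star = expfn(param_dict, x_star)
print(y_star)
```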
1 change: 1 addition & 0 deletions pyproject.toml
@@ -56,6 +56,7 @@ dependencies = [
"mkdocs-exclude",
"markdown-exec",
"mike",
"matplotlib",
]

[tool.hatch.build.targets.wheel]
