typing corrections
dominikandreasseitz committed Apr 8, 2024
1 parent 80b2c0a commit 42d37fd
Showing 1 changed file with 33 additions and 32 deletions.
docs/index.md (33 additions, 32 deletions)
@@ -94,8 +94,8 @@ psi_star = apply_gate(psi, hamevo, {"hamiltonian": Hamiltonian, "time_evolution"

We can now build a fully differentiable variational circuit by simply defining a sequence of gates
and a set of initial parameter values we want to optimize.
-Horqrux provides an implementation of [adjoint differentiation](https://arxiv.org/abs/2009.02823),
-which we can use to fit a function using a simple circuit class wrapper.
+`horqrux` provides an implementation of [adjoint differentiation](https://arxiv.org/abs/2009.02823),
+which we can use to fit a function using a simple `Circuit` class.
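
Before the full example, a note on what is being differentiated: adjoint differentiation computes gradients of an expectation value with respect to the gate parameters in a single backward sweep over the circuit, instead of storing intermediate states per parameter. The snippet below is a minimal, horqrux-independent sketch in plain JAX of the quantity involved; the single-qubit setup and the `expectation` function are illustrative assumptions, not horqrux API.

```python
import jax
import jax.numpy as jnp


# Illustrative toy circuit: |psi(theta)> = RY(theta)|0>, observable Z.
def expectation(theta: jax.Array) -> jax.Array:
    # RY(theta)|0> = [cos(theta/2), sin(theta/2)]
    psi = jnp.array([jnp.cos(theta / 2), jnp.sin(theta / 2)])
    pauli_z = jnp.array([[1.0, 0.0], [0.0, -1.0]])
    return jnp.real(psi.conj() @ pauli_z @ psi)


# jax.grad differentiates the expectation value; horqrux's adjoint mode
# computes the same gradient while reusing the propagated state.
print(jax.grad(expectation)(jnp.pi / 3))  # analytically -sin(pi/3), about -0.866
```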

```python exec="on" source="material-block" html="1"
from __future__ import annotations
@@ -178,15 +178,15 @@ def loss_fn(param_vals: Array, x: Array, y: Array) -> Array:
    return jnp.mean(optax.l2_loss(y_pred, y))


-def optimize_step(params: dict[str, Array], opt_state: Array, grads: dict[str, Array]) -> tuple:
+def optimize_step(param_vals: Array, opt_state: Array, grads: Array) -> tuple:
    updates, opt_state = optimizer.update(grads, opt_state)
-    params = optax.apply_updates(params, updates)
-    return params, opt_state
+    param_vals = optax.apply_updates(param_vals, updates)
+    return param_vals, opt_state

@jit
-def train_step(i: int, inputs: tuple
+def train_step(i: int, paramvals_w_optstate: tuple
) -> tuple:
-    param_vals, opt_state = inputs
+    param_vals, opt_state = paramvals_w_optstate
    loss, grads = value_and_grad(loss_fn)(param_vals, x, y)
    param_vals, opt_state = optimize_step(param_vals, opt_state, grads)
    return param_vals, opt_state
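# Note: `value_and_grad(loss_fn)` evaluates the loss and its gradients in a
# single pass, and `@jit` compiles the whole update into one XLA computation.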
@@ -214,7 +214,7 @@ print(fig_to_html(plt.gcf())) # markdown-exec: hide
```
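
The hunk above elides the training driver and plotting. As a reference for the pattern being refactored here, this is a minimal, self-contained sketch of the same `optax` train-step loop on a toy least-squares problem; the toy model, learning rate, and step count are illustrative assumptions, not values from the docs:

```python
from __future__ import annotations

import jax
import jax.numpy as jnp
import optax

x = jnp.linspace(0.0, 1.0, 32)
y = 2.0 * x  # target: y = 2x, so the fitted parameter should approach 2

optimizer = optax.adam(learning_rate=0.1)
theta = jnp.array(0.0)
opt_state = optimizer.init(theta)


def loss_fn(theta: jax.Array) -> jax.Array:
    return jnp.mean(optax.l2_loss(theta * x, y))


@jax.jit
def train_step(i: int, carry: tuple) -> tuple:
    theta, opt_state = carry
    loss, grads = jax.value_and_grad(loss_fn)(theta)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(theta, updates), opt_state


theta, opt_state = jax.lax.fori_loop(0, 200, train_step, (theta, opt_state))
print(theta)  # converges to approximately 2.0
```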
## Fitting a partial differential equation using DQC

-Finally, we show how to implement [DQC](https://arxiv.org/abs/2011.10395) to solve a partial differential equation.
+Finally, we show how [DQC](https://arxiv.org/abs/2011.10395) can be implemented in `horqrux` to solve a partial differential equation.
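
Concretely, the example trains a circuit to satisfy the two-dimensional Laplace equation on the unit square. Reading the boundary terms off the `pde_loss` comments below (and assuming the two elided terms mirror the left/right ones at $y=1$ and $y=0$):

$$
u_{xx} + u_{yy} = 0, \qquad (x, y) \in (0, 1)^2,
$$

$$
u(0, y) = u(1, y) = u(x, 1) = 0, \qquad u(x, 0) = \sin(\pi x),
$$

whose closed-form solution, presumably what the elided `analytic_sol` lines compute, is $u(x, y) = \sin(\pi x)\,\sinh(\pi(1 - y))/\sinh(\pi)$.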

```python exec="on" source="material-block" html="1"
from __future__ import annotations
@@ -241,6 +241,8 @@ LEARNING_RATE = 0.01
N_QUBITS = 4
DEPTH = 3
VARIABLES = ("x", "y")
+X_POS = 0
+Y_POS = 1
N_POINTS = 150
N_EPOCHS = 1000

@@ -288,17 +290,17 @@ class Circuit:
        self.ansatz, self.param_names = ansatz_w_params(self.n_qubits, self.n_layers)
        self.observable = TotalMagnetization(self.n_qubits)

-    def forward(self, param_values: Array, x: Array, y: Array) -> Array:
+    def forward(self, param_vals: Array, x: Array, y: Array) -> Array:
        state = zero_state(self.n_qubits)
-        param_dict = {name: val for name, val in zip(self.param_names, param_values)}
+        param_dict = {name: val for name, val in zip(self.param_names, param_vals)}
        out_state = apply_gate(
            state, self.feature_map + self.ansatz, {**param_dict, **{"x": x, "y": y}}
        )
        projected_state = self.observable(out_state, param_dict)
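        # `projected_state` is O|psi_out>; contracting it with <psi_out| below
        # yields the real-valued expectation <psi_out|O|psi_out>.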
        return jnp.real(inner(out_state, projected_state))

-    def __call__(self, param_values: Array, x: Array, y: Array) -> Array:
-        return self.forward(param_values, x, y)
+    def __call__(self, param_vals: Array, x: Array, y: Array) -> Array:
+        return self.forward(param_vals, x, y)

    @property
    def n_vparams(self) -> int:
@@ -308,21 +310,21 @@ class Circuit:
circ = Circuit(N_QUBITS, DEPTH)
# Create random initial values for the parameters
key = jax.random.PRNGKey(42)
-params = jax.random.uniform(key, shape=(circ.n_vparams,))
+param_vals = jax.random.uniform(key, shape=(circ.n_vparams,))

optimizer = optax.adam(learning_rate=0.01)
-opt_state = optimizer.init(params)
+opt_state = optimizer.init(param_vals)


-def exp_fn(param_values: Array, x: Array, y: Array) -> Array:
-    return circ(param_values, x, y)
+def exp_fn(param_vals: Array, x: Array, y: Array) -> Array:
+    return circ(param_vals, x, y)


-def loss_fn(param_values: Array, x: Array, y: Array) -> Array:
+def loss_fn(param_vals: Array, x: Array, y: Array) -> Array:
    def pde_loss(x: float, y: float) -> Array:
        l_b, r_b, t_b, b_b = list(
            map(
-                lambda inputs: exp_fn(param_values, *inputs),
+                lambda xy: exp_fn(param_vals, *xy),
                [
                    [jnp.zeros((1, 1)), y],  # u(0,y)=0
                    [jnp.ones((1, 1)), y],  # u(L,y)=0
@@ -332,7 +334,7 @@ def loss_fn(param_values: Array, x: Array, y: Array) -> Array:
            )
        )
        b_b -= jnp.sin(jnp.pi * x)
-        hessian = jax.hessian(lambda inputs: exp_fn(params, inputs[0], inputs[1]))(
+        hessian = jax.hessian(lambda xy: exp_fn(param_vals, xy[0], xy[1]))(
            jnp.concatenate(
                [
                    x.reshape(
@@ -344,16 +346,16 @@ def loss_fn(param_values: Array, x: Array, y: Array) -> Array:
                ]
            )
        )
-        interior = hessian[0][0] + hessian[1][1]  # uxx+uyy=0
-        return reduce(add, list(map(lambda t: jnp.power(t, 2), [l_b, r_b, t_b, b_b, interior])))
+        interior = hessian[X_POS][X_POS] + hessian[Y_POS][Y_POS]  # uxx+uyy=0
+        return reduce(add, list(map(lambda term: jnp.power(term, 2), [l_b, r_b, t_b, b_b, interior])))

    return jnp.mean(vmap(pde_loss, in_axes=(0, 0))(x, y))
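# `pde_loss` sums the squared residuals of the four boundary terms and the
# interior Laplacian term for one point; `vmap` maps it over the whole batch
# of collocation points and the mean reduces it to a scalar loss.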


-def optimize_step(params: dict[str, Array], opt_state: Array, grads: dict[str, Array]) -> tuple:
-    updates, opt_state = optimizer.update(grads, opt_state, params)
-    params = optax.apply_updates(params, updates)
-    return params, opt_state
+def optimize_step(param_vals: Array, opt_state: Array, grads: dict[str, Array]) -> tuple:
+    updates, opt_state = optimizer.update(grads, opt_state, param_vals)
+    param_vals = optax.apply_updates(param_vals, updates)
+    return param_vals, opt_state


# collocation points sampling and training
@@ -362,15 +364,14 @@ def sample_points(n_in: int, n_p: int) -> Array:


@jit
-def train_step(i: int, inputs: tuple) -> tuple:
-    params, opt_state = inputs
+def train_step(i: int, paramvals_w_optstate: tuple) -> tuple:
+    param_vals, opt_state = paramvals_w_optstate
    x, y = sample_points(2, N_POINTS)
-    loss, grads = value_and_grad(loss_fn)(params, x, y)
-    params, opt_state = optimize_step(params, opt_state, grads)
-    return params, opt_state
+    loss, grads = value_and_grad(loss_fn)(param_vals, x, y)
+    return optimize_step(param_vals, opt_state, grads)


-params, opt_state = jax.lax.fori_loop(0, N_EPOCHS, train_step, (params, opt_state))
+param_vals, opt_state = jax.lax.fori_loop(0, N_EPOCHS, train_step, (param_vals, opt_state))
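# `jax.lax.fori_loop(lower, upper, body_fun, init_val)` calls body_fun(i, carry)
# and threads the returned carry through, which is why `train_step` takes an
# (unused) iteration index and returns the (param_vals, opt_state) tuple.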
# compare the solution to known ground truth
single_domain = jnp.linspace(0, 1, num=N_POINTS)
domain = jnp.array(list(product(single_domain, single_domain)))
@@ -380,7 +381,7 @@ analytic_sol = (
)
# DQC solution

-dqc_sol = vmap(lambda domain: exp_fn(params, domain[0], domain[1]), in_axes=(0,))(domain).reshape(
+dqc_sol = vmap(lambda domain: exp_fn(param_vals, domain[0], domain[1]), in_axes=(0,))(domain).reshape(
    N_POINTS, N_POINTS
)
# plot results
