# qml_debugging.py (forked from qBraid/NYUAD-2023)
import numpy as np
import pandas as pd
import haiku as hk
import jax
import jax.numpy as jnp
import optax
import pennylane as qml
from sklearn.metrics import accuracy_score
# Load data
x_train = pd.read_excel("x_train.xlsx")
y_train = pd.read_excel("y_train.xlsx")
x_test = pd.read_excel("x_test.xlsx")
y_test = pd.read_excel("y_test.xlsx")

# Hyperparameters
n_qubits = 8
epochs = 500
batch_size = 32
num_layers = 8
num_batches = len(x_train) // batch_size

# Quantum device
dev = qml.device("default.qubit", wires=n_qubits)
def quantum_layer(weights):
    """
    Generates one layer of the QNN model.

    Args:
        weights (array[float]): Layer parameters of shape (n_qubits, 3).
    """
    qml.templates.AngleEmbedding(weights[:, 0], rotation="Y", wires=range(n_qubits))
    qml.templates.AngleEmbedding(weights[:, 1], rotation="Z", wires=range(n_qubits))
    # Ring of CNOTs entangling neighbouring qubits
    for i in range(n_qubits):
        qml.CNOT(wires=[i, (i + 1) % n_qubits])

@qml.qnode(dev, interface="jax")
def quantum_circuit(x, circuit_weights):
    """
    Builds the full quantum circuit model.

    Args:
        x (array[float]): The pressure-sensor features fed into the model.
        circuit_weights (array[float]): Trainable parameters of shape
            (num_layers, n_qubits, 3).

    Returns:
        list[float]: Expectation value of each qubit in the PauliZ basis.
    """
    for weights in circuit_weights:
        # Re-upload the data before every variational layer
        qml.templates.AngleEmbedding(x, wires=range(n_qubits))
        quantum_layer(weights)
    return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]

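# A minimal, optional sanity check that is not part of the original script: drawing the
# circuit for a dummy input and randomly initialised weights makes it easy to confirm the
# data re-uploading and entangling structure while debugging. The helper below is purely
# illustrative and is never called by the training code.
def draw_circuit_example():
    """Illustrative only: print the circuit structure for dummy inputs and random weights."""
    dummy_x = np.zeros(n_qubits)
    dummy_weights = np.random.normal(scale=0.25, size=(num_layers, n_qubits, 3))
    print(qml.draw(quantum_circuit)(dummy_x, dummy_weights))
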
@hk.without_apply_rng
@hk.transform
def qforward(x):
    """
    Runs a forward pass of the quantum model.

    Args:
        x (array[float]): The pressure-sensor features fed into the model.

    Returns:
        array[float]: Logits from the model, shape (batch, 1).
    """
    # Classical pre-processing: compress the input to n_qubits features in [-1, 1]
    x = jax.nn.tanh(hk.Linear(n_qubits)(x))
    W = hk.get_parameter(
        "W", (num_layers, n_qubits, 3), init=hk.initializers.RandomNormal(stddev=0.25)
    )
    # Evaluate the quantum circuit over the whole batch
    x = jax.vmap(quantum_circuit, in_axes=(0, None))(x, W)
    # Newer PennyLane versions return one array per measured qubit; stack them into a
    # single (batch, n_qubits) feature array so the final linear layer accepts it.
    if isinstance(x, (tuple, list)):
        x = jnp.stack(x, axis=-1)
    # Classical post-processing to a single logit per sample
    x = hk.Linear(1)(x)
    return x

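# Illustrative shape check, not in the original script: the quantum model maps
# (batch, n_features) -> Linear(n_qubits) -> per-qubit expectations (batch, n_qubits)
# -> Linear(1), i.e. one logit per sample. Assuming x_train has been loaded above,
# one could verify this as follows (variable names here are hypothetical):
# _demo_params = qforward.init(jax.random.PRNGKey(0), x_train.values[:4])
# _demo_logits = qforward.apply(_demo_params, x_train.values[:4])
# print(_demo_logits.shape)  # expected: (4, 1)
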
@hk.without_apply_rng
@hk.transform
def cforward(x):
    """
    Runs a forward pass of the classical baseline model.

    Args:
        x (array[float]): The pressure-sensor features fed into the model.

    Returns:
        array[float]: Logits from the model, shape (batch, 1).
    """
    nn = hk.Sequential([
        hk.Linear(10),
        jax.nn.relu,
        hk.Linear(10),
        jax.nn.relu,
        hk.Linear(1),
    ])
    return nn(x)

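# The classical baseline above is defined but never trained in this script. A minimal,
# hypothetical sketch of how it could be exercised with the same Haiku API (the names
# below are illustrative, not from the original file):
# c_params = cforward.init(jax.random.PRNGKey(0), x_train.values)
# c_logits = cforward.apply(c_params, x_train.values)
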
# Initialise the quantum model parameters and the optimizer
seed = 123
rng = jax.random.PRNGKey(seed)
params = qforward.init(rng, x_train.values)

opt = optax.radam(learning_rate=5e-4)
opt_state = opt.init(params)

# Loss and update functions
def loss_fn(params, x, y):
    """
    Computes the loss between the model's predictions and the ground truth.

    Args:
        params (hk.Params): Parameters fed into the model.
        x (array[float]): The pressure-sensor features fed into the model.
        y (array[int]): The ground-truth class labels.

    Returns:
        float: Mean sigmoid binary cross-entropy between the logits and the labels.
    """
    pred = qforward.apply(params, x)
    loss = optax.sigmoid_binary_cross_entropy(pred, y).mean()
    return loss

@jax.jit
def update(params, opt_state, x, y):
    """
    Updates the parameters using gradients of the loss.

    Args:
        params (hk.Params): Current model parameters.
        opt_state (optax.OptState): Optimizer state from the previous step.
        x (array[float]): The pressure-sensor features fed into the model.
        y (array[int]): The ground-truth class labels.

    Returns:
        tuple: Updated parameters, new optimizer state, and the loss value.
    """
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, new_opt_state = opt.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss

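# Optional smoke test, not in the original script: run a single jitted update step on a
# small slice of the training data to confirm shapes and gradients before the full loop
# (the underscore-prefixed names are illustrative only):
# _p, _s, _l = update(params, opt_state, x_train.values[:batch_size], y_train.values[:batch_size])
# print("single-step loss:", float(_l))
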
# Training loop
loss_list = []
test_acc = []
for epoch in range(epochs):
    # Shuffle the training data
    shuffled_indices = np.random.permutation(len(x_train))
    x_train_shuffled = x_train.values[shuffled_indices]
    y_train_shuffled = y_train.values[shuffled_indices]

    # Training
    epoch_loss = 0
    for batch_idx in range(num_batches):
        start = batch_idx * batch_size
        end = start + batch_size
        x_batch = x_train_shuffled[start:end]
        y_batch = y_train_shuffled[start:end]
        params, opt_state, batch_loss = update(params, opt_state, x_batch, y_batch)
        epoch_loss += batch_loss
    epoch_loss /= num_batches
    loss_list.append(epoch_loss)

    # Evaluate on the test set. The model outputs logits, so apply a sigmoid
    # before thresholding at 0.5 (equivalently, threshold the logits at 0).
    y_pred = qforward.apply(params, x_test.values)
    y_pred_labels = (jax.nn.sigmoid(y_pred) > 0.5).astype(int)
    test_accuracy = accuracy_score(y_test.values, y_pred_labels)
    test_acc.append(test_accuracy)

    print(f"Epoch {epoch + 1}, Loss: {epoch_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

# Final evaluation on the test set
y_pred = qforward.apply(params, x_test.values)
y_pred_labels = (jax.nn.sigmoid(y_pred) > 0.5).astype(int)
test_accuracy = accuracy_score(y_test.values, y_pred_labels)
print(f"Test Accuracy: {test_accuracy:.4f}")