Skip to content

Commit

Permalink
optimize shots align eigenvectors (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
chMoussa authored Mar 3, 2025
1 parent 2318333 commit 53be041
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions horqrux/shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def eigen_sample(
observables,
)
)
eigs = list(map(lambda mat: jnp.linalg.eigh(mat), mat_obs))
eigvecs, eigvals = align_eigenvectors(eigs)
eigs = jax.vmap(jnp.linalg.eigh)(jnp.stack(mat_obs))
eigvecs, eigvals = align_eigenvectors(eigs.eigenvalues, eigs.eigenvectors)
probs = eigen_probabilities(state, eigvecs)
return jax.random.choice(key=key, a=eigvals, p=probs, shape=(n_shots,)).mean(axis=0)

Expand All @@ -124,7 +124,7 @@ def finite_shots_fwd(
return eigen_sample(output_gates, observables, values, n_qubits, n_shots, key)


def align_eigenvectors(eigs: list[tuple[Array, Array]]) -> tuple[Array, Array]:
def align_eigenvectors(eigenvalues: Array, eigenvectors: Array) -> tuple[Array, Array]:
"""
Given a list of eigenvalue eigenvector matrix tuples in the form of
[(eigenvalue, eigenvector)...], this function aligns all the eigenvector
Expand All @@ -139,20 +139,19 @@ def align_eigenvectors(eigs: list[tuple[Array, Array]]) -> tuple[Array, Array]:
matrix and uses it to align each eigenvector matrix to the first eigenvector
matrix of eigs.
"""
eigenvalues = []
eigs_copy = eigs.copy()
eigenvalue, eigenvector_matrix = eigs_copy.pop(0)
eigenvalues.append(eigenvalue)
# TODO: laxify this loop
for mat in eigs_copy:
inv = jnp.linalg.inv(mat[1])
P = (inv @ eigenvector_matrix).real > 0.5
checkify.check(
validate_permutation_matrix(P),
"Did not calculate valid permutation matrix",
)
eigenvalues.append(mat[0] @ P)
return eigenvector_matrix, jnp.stack(eigenvalues, axis=1)
eigenvector_matrix = eigenvectors[0]

P = jax.vmap(lambda mat: permutation_matrix(mat, eigenvector_matrix))(eigenvectors)
checkify.check(
jnp.all(jax.vmap(validate_permutation_matrix)(P)),
"Did not calculate valid permutation matrix",
)
eigenvalues_aligned = jax.vmap(jnp.dot)(eigenvalues, P).T
return eigenvector_matrix, eigenvalues_aligned


def permutation_matrix(mat: Array, eigenvector_matrix: Array) -> Array:
return (jnp.linalg.inv(mat) @ eigenvector_matrix).real > 0.5


def validate_permutation_matrix(P: Array) -> Array:
Expand Down

0 comments on commit 53be041

Please sign in to comment.