Skip to content

Commit

Permalink
perf(tensorflow): Interferometer gradient
Browse files Browse the repository at this point in the history
Loop-invariant row indexing has been moved outside of the for loops to improve performance.

Moreover, if the upstream is statically valued, the calculation is done with
NumPy, so the matrices do not need to be padded, since NumPy arrays support
in-place item assignment.

Finally, `lru_cache`-based caching has been added to several index-calculation functions.
  • Loading branch information
Kolarovszki committed Feb 3, 2024
1 parent 873f464 commit 56d36c2
Showing 1 changed file with 41 additions and 23 deletions.
64 changes: 41 additions & 23 deletions piquasso/_backends/fock/pure/calculations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

from typing import Optional, Tuple, Mapping, List, Callable

from functools import lru_cache

import random
import numpy as np

Expand Down Expand Up @@ -88,6 +90,7 @@ def _get_interferometer_with_gradient_callback(interferometer):
return wrapped(interferometer)


@lru_cache(maxsize=None)
def _calculate_interferometer_helper_indices(space):
d = space.d
cutoff = space.cutoff
Expand Down Expand Up @@ -191,16 +194,19 @@ def _calculate_interferometer_on_fock_space(interferometer, index_dict):
sqrt_occupation_numbers = index_dict["sqrt_occupation_numbers_tensor"][n - 2]
first_occupation_numbers = index_dict["first_occupation_numbers_tensor"][n - 2]

first_part_partially_indexed = interferometer[first_nonzero_indices, :]
second_part_partially_indexed = subspace_representations[n - 1][
first_subspace_indices, :
]

matrix = []

for index in range(size):
first_part = (
sqrt_occupation_numbers[index]
* interferometer[np.ix_(first_nonzero_indices, nonzero_indices[index])]
* first_part_partially_indexed[:, nonzero_indices[index]]
)
second_part = subspace_representations[n - 1][
np.ix_(first_subspace_indices, subspace_indices[index])
]
second_part = second_part_partially_indexed[:, subspace_indices[index]]
matrix.append(np.einsum("ij,ij->i", first_part, second_part))

new_subspace_representation = np.transpose(
Expand Down Expand Up @@ -270,28 +276,32 @@ def interferometer_gradient(*upstream):
sqrt_occupation_numbers = sqrt_occupation_numbers_tensor[p - 2]
first_occupation_numbers = first_occupation_numbers_tensor[p - 2]

partial_interferometer = interferometer[first_nonzero_indices, :]
partial_previous_subspace_grad = previous_subspace_grad[
first_subspace_indices, :
]

for n_index in range(size):
first_part = (
sqrt_occupation_numbers[n_index]
* interferometer[
fallback_np.ix_(
first_nonzero_indices, nonzero_indices[n_index]
)
]
* partial_interferometer[:, nonzero_indices[n_index]]
)
second_part = previous_subspace_grad[
fallback_np.ix_(
first_subspace_indices, subspace_indices[n_index]
)
second_part = partial_previous_subspace_grad[
:, subspace_indices[n_index]
]
full = fallback_np.einsum("ij,ij->i", first_part, second_part)
matrix[:, n_index] = full / fallback_np.sqrt(
first_occupation_numbers
)
matrix[:, n_index] = full

matrix = (matrix.T / fallback_np.sqrt(first_occupation_numbers)).T

mp1i_indices = fallback_np.where(
fallback_np.asarray(first_nonzero_indices) == row_index
)[0]

occupation_number_sqrts = fallback_np.sqrt(
first_occupation_numbers[mp1i_indices]
)

col_nonzero_indices_index = []
col_nonzero_indices = []
for index in range(size):
Expand All @@ -309,7 +319,7 @@ def interferometer_gradient(*upstream):
sqrt_occupation_numbers[col_nonzero_indices[index]][
col_nonzero_indices_index[index]
]
/ fallback_np.sqrt(first_occupation_numbers[mp1i_indices])
/ occupation_number_sqrts
* subspace_representations[p - 1][
first_subspace_indices[mp1i_indices], nm1l_index
]
Expand Down Expand Up @@ -406,7 +416,6 @@ def _apply_interferometer_matrix(state_vector, subspace_transformations):
index_list = _calculate_index_list_for_appling_interferometer(
modes,
space,
calculator,
)

new_state_vector = _calculate_state_vector_after_interferometer(
Expand All @@ -428,13 +437,14 @@ def _apply_interferometer_matrix(state_vector, subspace_transformations):
state._state_vector = wrapped(state_vector, subspace_transformations)


@lru_cache(maxsize=None)
def _calculate_index_list_for_appling_interferometer(
modes: Tuple[int, ...],
space: FockSpace,
calculator: BaseCalculator,
) -> List[np.ndarray]:
cutoff = space.cutoff
d = space.d
calculator = space._calculator

subspace = FockSpace(
d=len(modes), cutoff=space.cutoff, calculator=calculator, config=space.config
Expand Down Expand Up @@ -569,6 +579,7 @@ def _apply_active_gate_matrix(state_vector, matrix):
state._state_vector = _apply_active_gate_matrix(state_vector, matrix)


@lru_cache(maxsize=None)
def _calculate_state_index_matrix_list(space, auxiliary_subspace, mode):
d = space.d
cutoff = space.cutoff
Expand Down Expand Up @@ -655,12 +666,19 @@ def linear_active_gate_gradient(upstream):
(np.einsum("ji,jk->ik", matrix_slice, upstream_slice)).reshape(-1)
)

gradient_by_matrix += np.pad(
np.einsum("ij,kj->ki", state_vector_slice, upstream_slice),
[[0, cutoff - limit], [0, cutoff - limit]],
"constant",
partial_gradient = np.einsum(
"ij,kj->ki", state_vector_slice, upstream_slice
)

if static_valued:
gradient_by_matrix[:limit, :limit] += partial_gradient
else:
gradient_by_matrix += np.pad(
partial_gradient,
[[0, cutoff - limit], [0, cutoff - limit]],
"constant",
)

gradient_by_initial_state = np.take(
np.concatenate(unordered_gradient_by_initial_state),
fallback_np.concatenate(order_by).argsort(),
Expand Down

0 comments on commit 56d36c2

Please sign in to comment.