diff --git a/piquasso/_backends/fock/pure/calculations.py b/piquasso/_backends/fock/pure/calculations.py index 0a532878..aea79313 100644 --- a/piquasso/_backends/fock/pure/calculations.py +++ b/piquasso/_backends/fock/pure/calculations.py @@ -15,6 +15,8 @@ from typing import Optional, Tuple, Mapping, List, Callable +from functools import lru_cache + import random import numpy as np @@ -88,6 +90,7 @@ def _get_interferometer_with_gradient_callback(interferometer): return wrapped(interferometer) +@lru_cache(maxsize=None) def _calculate_interferometer_helper_indices(space): d = space.d cutoff = space.cutoff @@ -191,16 +194,19 @@ def _calculate_interferometer_on_fock_space(interferometer, index_dict): sqrt_occupation_numbers = index_dict["sqrt_occupation_numbers_tensor"][n - 2] first_occupation_numbers = index_dict["first_occupation_numbers_tensor"][n - 2] + first_part_partially_indexed = interferometer[first_nonzero_indices, :] + second_part_partially_indexed = subspace_representations[n - 1][ + first_subspace_indices, : + ] + matrix = [] for index in range(size): first_part = ( sqrt_occupation_numbers[index] - * interferometer[np.ix_(first_nonzero_indices, nonzero_indices[index])] + * first_part_partially_indexed[:, nonzero_indices[index]] ) - second_part = subspace_representations[n - 1][ - np.ix_(first_subspace_indices, subspace_indices[index]) - ] + second_part = second_part_partially_indexed[:, subspace_indices[index]] matrix.append(np.einsum("ij,ij->i", first_part, second_part)) new_subspace_representation = np.transpose( @@ -270,28 +276,32 @@ def interferometer_gradient(*upstream): sqrt_occupation_numbers = sqrt_occupation_numbers_tensor[p - 2] first_occupation_numbers = first_occupation_numbers_tensor[p - 2] + partial_interferometer = interferometer[first_nonzero_indices, :] + partial_previous_subspace_grad = previous_subspace_grad[ + first_subspace_indices, : + ] + for n_index in range(size): first_part = ( sqrt_occupation_numbers[n_index] - * interferometer[ - fallback_np.ix_( - first_nonzero_indices, nonzero_indices[n_index] - ) - ] + * partial_interferometer[:, nonzero_indices[n_index]] ) - second_part = previous_subspace_grad[ - fallback_np.ix_( - first_subspace_indices, subspace_indices[n_index] - ) + second_part = partial_previous_subspace_grad[ + :, subspace_indices[n_index] ] full = fallback_np.einsum("ij,ij->i", first_part, second_part) - matrix[:, n_index] = full / fallback_np.sqrt( - first_occupation_numbers - ) + matrix[:, n_index] = full + + matrix = (matrix.T / fallback_np.sqrt(first_occupation_numbers)).T mp1i_indices = fallback_np.where( fallback_np.asarray(first_nonzero_indices) == row_index )[0] + + occupation_number_sqrts = fallback_np.sqrt( + first_occupation_numbers[mp1i_indices] + ) + col_nonzero_indices_index = [] col_nonzero_indices = [] for index in range(size): @@ -309,7 +319,7 @@ def interferometer_gradient(*upstream): sqrt_occupation_numbers[col_nonzero_indices[index]][ col_nonzero_indices_index[index] ] - / fallback_np.sqrt(first_occupation_numbers[mp1i_indices]) + / occupation_number_sqrts * subspace_representations[p - 1][ first_subspace_indices[mp1i_indices], nm1l_index ] @@ -406,7 +416,6 @@ def _apply_interferometer_matrix(state_vector, subspace_transformations): index_list = _calculate_index_list_for_appling_interferometer( modes, space, - calculator, ) new_state_vector = _calculate_state_vector_after_interferometer( @@ -428,13 +437,14 @@ def _apply_interferometer_matrix(state_vector, subspace_transformations): state._state_vector = wrapped(state_vector, subspace_transformations) +@lru_cache(maxsize=None) def _calculate_index_list_for_appling_interferometer( modes: Tuple[int, ...], space: FockSpace, - calculator: BaseCalculator, ) -> List[np.ndarray]: cutoff = space.cutoff d = space.d + calculator = space._calculator subspace = FockSpace( d=len(modes), cutoff=space.cutoff, calculator=calculator, config=space.config @@ -569,6 +579,7 @@ def _apply_active_gate_matrix(state_vector, matrix): state._state_vector = _apply_active_gate_matrix(state_vector, matrix) +@lru_cache(maxsize=None) def _calculate_state_index_matrix_list(space, auxiliary_subspace, mode): d = space.d cutoff = space.cutoff @@ -655,12 +666,19 @@ def linear_active_gate_gradient(upstream): (np.einsum("ji,jk->ik", matrix_slice, upstream_slice)).reshape(-1) ) - gradient_by_matrix += np.pad( - np.einsum("ij,kj->ki", state_vector_slice, upstream_slice), - [[0, cutoff - limit], [0, cutoff - limit]], - "constant", + partial_gradient = np.einsum( + "ij,kj->ki", state_vector_slice, upstream_slice ) + if static_valued: + gradient_by_matrix[:limit, :limit] += partial_gradient + else: + gradient_by_matrix += np.pad( + partial_gradient, + [[0, cutoff - limit], [0, cutoff - limit]], + "constant", + ) + gradient_by_initial_state = np.take( np.concatenate(unordered_gradient_by_initial_state), fallback_np.concatenate(order_by).argsort(),