Skip to content

Commit

Permalink
feat: Batch processing of pure Fock states
Browse files Browse the repository at this point in the history
The batch processing of pure Fock states has been implemented. This is done by
storing multiple states in a tensor of shape `(state_vector_size,
number_of_batches)` in `BatchPureFockState`. However, `BatchPureFockState` has
a limited support of methods compared to the original `PureFockState`.
  • Loading branch information
Kolarovszki committed Feb 10, 2024
1 parent 39e4146 commit 8a9a911
Show file tree
Hide file tree
Showing 13 changed files with 1,316 additions and 50 deletions.
10 changes: 10 additions & 0 deletions piquasso/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from piquasso._backends.fock import (
FockState,
PureFockState,
BatchPureFockState,
FockSimulator,
PureFockSimulator,
)
Expand Down Expand Up @@ -90,6 +91,11 @@
LossyInterferometer,
)

from .instructions.batch import (
BatchPrepare,
BatchApply,
)


__all__ = [
# API
Expand All @@ -115,6 +121,7 @@
"SamplingState",
"FockState",
"PureFockState",
"BatchPureFockState",
# Preparations
"Vacuum",
"Mean",
Expand Down Expand Up @@ -154,6 +161,9 @@
"Attenuator",
"Loss",
"LossyInterferometer",
# Batch
"BatchPrepare",
"BatchApply",
]

__version__ = "3.0.0"
1 change: 1 addition & 0 deletions piquasso/_backends/fock/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@
from .general.simulator import FockSimulator # noqa: F401

from .pure.state import PureFockState # noqa: F401
from .pure.batch_state import BatchPureFockState # noqa: F401
from .pure.simulator import PureFockSimulator # noqa: F401
149 changes: 149 additions & 0 deletions piquasso/_backends/fock/pure/batch_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#
# Copyright 2021-2023 Budapest Quantum Computing Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import numpy as np

from piquasso.api.config import Config
from piquasso.api.exceptions import InvalidState
from piquasso.api.calculator import BaseCalculator

from piquasso._math.linalg import vector_absolute_square
from piquasso._math.indices import get_index_in_fock_space

from .state import PureFockState


class BatchPureFockState(PureFockState):
r"""A simulated batch pure Fock state, containing multiple state vectors."""

def __init__(
self, *, d: int, calculator: BaseCalculator, config: Optional[Config] = None
) -> None:
"""
Args:
d (int): The number of modes.
calculator (BaseCalculator): Instance containing calculation functions.
config (Config): Instance containing constants for the simulation.
"""

super().__init__(d=d, calculator=calculator, config=config)

def _apply_separate_state_vectors(self, state_vectors):
self._state_vector = self._np.array(
state_vectors, dtype=self._config.complex_dtype
).T

@property
def _batch_size(self):
return self._state_vector.shape[1]

@property
def _batch_state_vectors(self):
for index in range(self._batch_size):
yield self._state_vector[:, index]

@property
def nonzero_elements(self):
return [
self._nonzero_elements_for_single_state_vector(state_vector)
for state_vector in self._batch_state_vectors
]

def __repr__(self) -> str:
partial_strings = []
for partial_nonzero_elements in self.nonzero_elements:
partial_strings.append(
self._get_repr_for_single_state_vector(partial_nonzero_elements)
)

return "\n".join(partial_strings)

def __eq__(self, other: object) -> bool:
if not isinstance(other, BatchPureFockState):
return False
return self._np.allclose(self._state_vector, other._state_vector)

@property
def fock_probabilities(self) -> np.ndarray:
return [
vector_absolute_square(state_vector, self._calculator)
for state_vector in self._batch_state_vectors
]

@property
def norm(self):
return [
self._calculator.np.sum(partial_fock_probabilities)
for partial_fock_probabilities in self.fock_probabilities
]

def normalize(self) -> None:
if not self._config.normalize:
return

norms = self.norm

if any(np.isclose(norm, 0) for norm in norms):
raise InvalidState("The norm of a state in the batch is 0.")

self._state_vector = self._state_vector / self._np.sqrt(norms)

def validate(self) -> None:
return True

def _get_mean_position_indices(self, mode):
fallback_np = self._calculator.fallback_np

left_indices = []
multipliers = []
right_indices = []

for index, basis in enumerate(self._space):
i = basis[mode]
basis_array = fallback_np.array(basis)

if i > 0:
basis_array[mode] = i - 1
lower_index = get_index_in_fock_space(tuple(basis_array))

left_indices.append(lower_index)
multipliers.append(fallback_np.sqrt(i))
right_indices.append(index)

if sum(basis) + 1 < self._config.cutoff:
basis_array[mode] = i + 1
upper_index = get_index_in_fock_space(tuple(basis_array))

left_indices.append(upper_index)
multipliers.append(fallback_np.sqrt(i + 1))
right_indices.append(index)

multipliers = fallback_np.array(multipliers)

return multipliers, left_indices, right_indices

def mean_position(self, mode: int) -> np.ndarray:
np = self._calculator.np
fallback_np = self._calculator.fallback_np
multipliers, left_indices, right_indices = self._get_mean_position_indices(mode)

lhs = (multipliers * self._state_vector[left_indices].T).T
rhs = self._state_vector[right_indices]

return np.real(
np.einsum("ij,ij->j", lhs, rhs) * fallback_np.sqrt(self._config.hbar / 2)
)
114 changes: 90 additions & 24 deletions piquasso/_backends/fock/pure/calculations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from piquasso.api.calculator import BaseCalculator

from .state import PureFockState
from .batch_state import BatchPureFockState

from piquasso.instructions import gates

Expand Down Expand Up @@ -492,8 +493,14 @@ def _calculate_state_vector_after_interferometer(
) -> np.ndarray:
new_state_vector = np.empty_like(state_vector)

is_batch = len(state_vector.shape) == 2

einsum_string = "ij,jkl->ikl" if is_batch else "ij,jk->ik"

for n, indices in enumerate(index_list):
new_state_vector[indices] = subspace_transformations[n] @ state_vector[indices]
new_state_vector[indices] = np.einsum(
einsum_string, subspace_transformations[n], state_vector[indices]
)

return new_state_vector

Expand All @@ -516,34 +523,40 @@ def applying_interferometer_gradient(upstream):
else:
np = calculator.np

is_batch = len(state_vector.shape) == 2

matrix_einsum_string = "ijl,kjl->ki" if is_batch else "ij,kj->ki"
initial_state_einsum_string = "ji,jkl->ikl" if is_batch else "ji,jk->ik"

reshape_arg = (-1, state_vector.shape[1]) if is_batch else (-1,)

unordered_gradient_by_initial_state = []
order_by = []

gradient_by_matrix = []

conjugated_state_vector = np.conj(state_vector)

for n, indices in enumerate(index_list):
matrix = np.conj(subspace_transformations[n])
sliced_upstream = np.take(upstream, indices)
sliced_upstream = upstream[indices]
state_vector_slice = conjugated_state_vector[indices]

order_by.append(indices.reshape(-1))
product = np.einsum("ji,jk->ik", matrix, sliced_upstream)
unordered_gradient_by_initial_state.append(product.reshape(-1))
product = np.einsum(initial_state_einsum_string, matrix, sliced_upstream)
unordered_gradient_by_initial_state.append(product.reshape(*reshape_arg))

gradient_by_matrix.append(
np.einsum("ij,kj->ki", state_vector_slice, sliced_upstream)
np.einsum(matrix_einsum_string, state_vector_slice, sliced_upstream)
)

gradient_by_initial_state = np.take(
np.concatenate(unordered_gradient_by_initial_state),
fallback_np.concatenate(order_by).argsort(),
)
gradient_by_initial_state = np.concatenate(unordered_gradient_by_initial_state)[
fallback_np.concatenate(order_by).argsort()
]

if static_valued:
return (
tf.constant(gradient_by_initial_state),
[tf.constant(matrix) for matrix in gradient_by_matrix],
)
gradient_by_initial_state = tf.constant(gradient_by_initial_state)
gradient_by_matrix = [tf.constant(matrix) for matrix in gradient_by_matrix]

return gradient_by_initial_state, gradient_by_matrix

Expand Down Expand Up @@ -617,10 +630,14 @@ def _calculate_state_vector_after_apply_active_gate(
):
new_state_vector = np.empty_like(state_vector, dtype=state_vector.dtype)

is_batch = len(state_vector.shape) == 2

einsum_string = "ij,jkl->ikl" if is_batch else "ij,jk->ik"

for state_index_matrix in state_index_matrix_list:
limit = state_index_matrix.shape[0]
new_state_vector[state_index_matrix] = (
matrix[:limit, :limit] @ state_vector[state_index_matrix]
new_state_vector[state_index_matrix] = np.einsum(
einsum_string, matrix[:limit, :limit], state_vector[state_index_matrix]
)

return new_state_vector
Expand All @@ -646,6 +663,13 @@ def linear_active_gate_gradient(upstream):

cutoff = len(matrix)

is_batch = len(state_vector.shape) == 2

matrix_einsum_string = "ijl,kjl->ki" if is_batch else "ij,kj->ki"
initial_state_einsum_string = "ji,jkl->ikl" if is_batch else "ji,jk->ik"

reshape_arg = (-1, state_vector.shape[1]) if is_batch else (-1,)

unordered_gradient_by_initial_state = []
order_by = []

Expand All @@ -656,18 +680,19 @@ def linear_active_gate_gradient(upstream):
for indices in state_index_matrix_list:
limit = indices.shape[0]

upstream_slice = np.take(upstream, indices)
upstream_slice = upstream[indices]
state_vector_slice = conjugated_state_vector[indices]

matrix_slice = conjugated_matrix[:limit, :limit]

order_by.append(indices.reshape(-1))
unordered_gradient_by_initial_state.append(
(np.einsum("ji,jk->ik", matrix_slice, upstream_slice)).reshape(-1)
product = np.einsum(
initial_state_einsum_string, matrix_slice, upstream_slice
)
unordered_gradient_by_initial_state.append(product.reshape(*reshape_arg))

partial_gradient = np.einsum(
"ij,kj->ki", state_vector_slice, upstream_slice
matrix_einsum_string, state_vector_slice, upstream_slice
)

if static_valued:
Expand All @@ -679,10 +704,9 @@ def linear_active_gate_gradient(upstream):
"constant",
)

gradient_by_initial_state = np.take(
np.concatenate(unordered_gradient_by_initial_state),
fallback_np.concatenate(order_by).argsort(),
)
gradient_by_initial_state = np.concatenate(unordered_gradient_by_initial_state)[
fallback_np.concatenate(order_by).argsort()
]

if static_valued:
return (
Expand Down Expand Up @@ -725,7 +749,8 @@ def kerr(state: PureFockState, instruction: Instruction, shots: int) -> Result:
1j * xi * np.array([basis[mode] ** 2 for basis in state._space])
)

state._state_vector = coefficients * state._state_vector
# NOTE: Transposition is done here in order to work with batch processing.
state._state_vector = (coefficients * state._state_vector.T).T

return Result(state=state)

Expand Down Expand Up @@ -828,3 +853,44 @@ def _add_occupation_number_basis( # type: ignore
state._state_vector = state._calculator.assign(
state._state_vector, index, coefficient
)


def batch_prepare(state: PureFockState, instruction: Instruction, shots: int):
subprograms = instruction._all_params["subprograms"]
execute = instruction._all_params["execute"]

state_vectors = [
execute(subprogram, shots).state._state_vector for subprogram in subprograms
]

batch_state = BatchPureFockState(
d=state.d, calculator=state._calculator, config=state._config
)

batch_state._apply_separate_state_vectors(state_vectors)

return Result(state=batch_state)


def batch_apply(state: BatchPureFockState, instruction: Instruction, shots: int):
subprograms = instruction._all_params["subprograms"]
execute = instruction._all_params["execute"]

d = state.d
calculator = state._calculator
config = state._config

resulting_state_vectors = []

for state_vector, subprogram in zip(state._batch_state_vectors, subprograms):
small_state = PureFockState(d=d, calculator=calculator, config=config)

small_state._state_vector = state_vector

resulting_state_vectors.append(
execute(subprogram, initial_state=small_state).state._state_vector
)

state._apply_separate_state_vectors(resulting_state_vectors)

return Result(state=state)
Loading

0 comments on commit 8a9a911

Please sign in to comment.