Skip to content

Commit

Permalink
feat(fock/pure): get_tensor_representation
Browse files Browse the repository at this point in the history
`PureFockSimulator` calculates with a global cutoff (i.e., cutoff
in the total number of particles), therefore it does not represent
the state vector as a tensor. `get_tensor_representation` is
created to reshape the current state vector representation into a
tensor.
  • Loading branch information
Kolarovszki committed Feb 9, 2024
1 parent d8afae6 commit c450c08
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 9 deletions.
7 changes: 4 additions & 3 deletions piquasso/_backends/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ def assign(self, array, index, value):

return array

def scatter(self, indices, updates, dim):
embedded_matrix = np.zeros((dim,) * 2, dtype=complex)
composite_index = np.array(indices)[:, 0], np.array(indices)[:, 1]
def scatter(self, indices, updates, shape):
embedded_matrix = np.zeros(shape, dtype=complex)
indices_array = np.array(indices)
composite_index = tuple([indices_array[:, i] for i in range(len(shape))])

embedded_matrix[composite_index] = np.array(updates)

Expand Down
18 changes: 18 additions & 0 deletions piquasso/_backends/fock/pure/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,21 @@ def mean_photon_number(self):
accumulator += number * self._np.abs(self._state_vector[index]) ** 2

return accumulator

def get_tensor_representation(self):
calculator = self._calculator
cutoff = self._config.cutoff
d = self.d

indices = []
updates = []

for index, number in enumerate(self._space):
indices.append(list(number))
updates.append(self._state_vector[index])

return calculator.scatter(
indices,
updates,
[cutoff] * d,
)
6 changes: 3 additions & 3 deletions piquasso/_backends/tensorflow/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def block(self, arrays):

return self.np.stack(output)

def scatter(self, indices, updates, dim):
return self._tf.scatter_nd(indices, updates, (dim, dim))
def scatter(self, indices, updates, shape):
return self._tf.scatter_nd(indices, updates, shape)

def embed_in_identity(self, matrix, indices, dim):
tf_indices = []
Expand All @@ -100,7 +100,7 @@ def embed_in_identity(self, matrix, indices, dim):
tf_indices.append(diagonal_index)
updates.append(1.0)

return self.scatter(tf_indices, updates, dim)
return self.scatter(tf_indices, updates, (dim, dim))

def _funm(self, matrix, func):
eigenvalues, U = self._tf.linalg.eig(matrix)
Expand Down
2 changes: 1 addition & 1 deletion piquasso/_math/fock.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def embed_matrix(
indices.append(embedded_index)
updates.append(matrix[index][0])

return self._calculator.scatter(indices, updates, dim=self.cardinality)
return self._calculator.scatter(indices, updates, shape=(self.cardinality,) * 2)

def get_linear_fock_operator(
self,
Expand Down
2 changes: 1 addition & 1 deletion piquasso/api/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def loop_hafnian(
def assign(self, array, index, value):
raise NotImplementedCalculation()

def scatter(self, indices, updates, dim):
def scatter(self, indices, updates, shape):
raise NotImplementedCalculation()

def embed_in_identity(self, matrix, indices, dim):
Expand Down
2 changes: 1 addition & 1 deletion tests/api/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_BaseCalculator_raises_NotImplementedCalculation_for_scatter(
empty_calculator,
):
with pytest.raises(NotImplementedCalculation):
empty_calculator.scatter(indices=[], updates=[], dim=2)
empty_calculator.scatter(indices=[], updates=[], shape=(3, 3))


def test_BaseCalculator_raises_NotImplementedCalculation_for_embed_in_identity(
Expand Down
29 changes: 29 additions & 0 deletions tests/backends/fock/pure/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,32 @@ def test_normalize_if_disabled_in_Config():
norm = state.norm

assert not np.isclose(norm, 1.0)


def test_PureFockState_get_tensor_representation():
d = 2
cutoff = 3

with pq.Program() as program:
pq.Q() | pq.StateVector([0, 1]) / 2

pq.Q() | pq.StateVector([0, 2]) / 2
pq.Q() | pq.StateVector([2, 0]) / np.sqrt(2)

simulator = pq.PureFockSimulator(d=d, config=pq.Config(cutoff=cutoff))
state = simulator.execute(program).state

state_tensor = state.get_tensor_representation()

assert state_tensor.shape == (3,) * 2

assert np.allclose(
state_tensor,
np.array(
[
[0.0 + 0.0j, 0.5 + 0.0j, 0.5 + 0.0j],
[0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
[0.70710678 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
]
),
)

0 comments on commit c450c08

Please sign in to comment.