Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added mm_einsum #515

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 211 additions & 0 deletions mrmustard/physics/mm_einsum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from numba import njit
import numpy as np
import itertools

Check warning on line 18 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L16-L18

Added lines #L16 - L18 were not covered by tests


@njit
def _CV_flops(nA: int, nB: int, m: int) -> int:

Check warning on line 22 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L21-L22

Added lines #L21 - L22 were not covered by tests
"""Calculate the cost of contracting two tensors with CV indices.
Args:
nA: Number of CV indices in the first tensor
nB: Number of CV indices in the second tensor
m: Number of CV indices involved in the contraction
"""
cost = (

Check warning on line 29 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L29

Added line #L29 was not covered by tests
m * m * m # M inverse
+ (m + 1) * m * nA # left matmul
+ (m + 1) * m * nB # right matmul
+ (m + 1) * m # addition
+ m * m * m
) # determinant of M
return cost

Check warning on line 36 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L36

Added line #L36 was not covered by tests


def new_indices_and_flops(

Check warning on line 39 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L39

Added line #L39 was not covered by tests
ziofil marked this conversation as resolved.
Show resolved Hide resolved
idx1: frozenset[int], idx2: frozenset[int], fock_size_dict: dict[int, int]
) -> tuple[frozenset[int], int]:
"""Calculate the cost of contracting two tensors with mixed CV and Fock indices.

This function computes both the surviving indices and the computational cost (in FLOPS)
of contracting two tensors that contain a mixture of continuous-variable (CV) and
Fock-space indices.

Args:
idx1: Set of indices for the first tensor. CV indices are integers not present
in fock_size_dict.
idx2: Set of indices for the second tensor. CV indices are integers not present
in fock_size_dict.
fock_size_dict: Dict mapping Fock index labels to their dimensions. Any index
not in this dict is treated as a CV index.

Returns:
tuple[frozenset[int], int]: A tuple containing:
- frozenset of indices that survive the contraction
- total computational cost in FLOPS (including CV operations,
Fock contractions, and potential decompositions)

Example:
>>> idx1 = frozenset({0, 1}) # 0 is CV, 1 is Fock
>>> idx2 = frozenset({1, 2}) # 2 is Fock
>>> fock_size_dict = {1: 2, 2: 3}
>>> new_indices_and_flops(idx1, idx2, fock_size_dict)
(frozenset({0, 2}), 9) # Example values
"""

# Calculate index sets for contraction
contracted_indices = idx1 & idx2 # Indices that get contracted away
remaining_indices = idx1 ^ idx2 # Indices that remain after contraction
all_fock_indices = set(fock_size_dict.keys())

Check warning on line 73 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L71-L73

Added lines #L71 - L73 were not covered by tests

# Count CV and get Fock shapes
num_cv_contracted = len(contracted_indices - all_fock_indices)
fock_contracted_shape = [fock_size_dict[idx] for idx in contracted_indices & all_fock_indices]
fock_remaining_shape = [fock_size_dict[idx] for idx in remaining_indices & all_fock_indices]

Check warning on line 78 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L76-L78

Added lines #L76 - L78 were not covered by tests

# Calculate flops
cv_flops = _CV_flops(

Check warning on line 81 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L81

Added line #L81 was not covered by tests
nA=len(idx1) - num_cv_contracted, nB=len(idx2) - num_cv_contracted, m=num_cv_contracted
)

if len(fock_contracted_shape) > 0:
fock_flops = np.prod(fock_contracted_shape) * np.prod(fock_remaining_shape)

Check warning on line 86 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L85-L86

Added lines #L85 - L86 were not covered by tests
else:
fock_flops = 0

Check warning on line 88 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L88

Added line #L88 was not covered by tests

# Try decomposing the remaining indices
new_indices, decomp_flops = attempt_decomposition(remaining_indices, fock_size_dict)

Check warning on line 91 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L91

Added line #L91 was not covered by tests

# pretending that we call the ansatz with the remaining indices
call_flops = np.prod([fock_size_dict[idx] for idx in new_indices if idx in fock_size_dict])

Check warning on line 94 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L94

Added line #L94 was not covered by tests

total_flops = int(cv_flops + fock_flops + decomp_flops + call_flops)
return new_indices, total_flops

Check warning on line 97 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L96-L97

Added lines #L96 - L97 were not covered by tests


def attempt_decomposition(

Check warning on line 100 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L100

Added line #L100 was not covered by tests
indices: set[int], fock_size_dict: dict[int, int]
) -> tuple[set[int], int]:
"""Attempt to reduce the number of indices by combining Fock indices when possible.
Only possible if there is only one CV index and multiple Fock indices.

Args:
indices: Set of indices to potentially decompose
fock_size_dict: Dictionary mapping indices to their sizes

Returns:
Tuple of (decomposed indices, cost of decomposition)
"""
fock_indices_shape = [fock_size_dict[idx] for idx in indices if idx in fock_size_dict]
cv_indices = [idx for idx in indices if idx not in fock_size_dict]

Check warning on line 114 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L113-L114

Added lines #L113 - L114 were not covered by tests

if len(cv_indices) == 1 and len(fock_indices_shape) > 1:
new_index = max(fock_size_dict) + 1 # Create new index with size = sum of Fock index sizes
decomposed_indices = {cv_indices[0], new_index}
fock_size_dict[new_index] = sum(fock_indices_shape)
decomp_flops = np.prod(fock_indices_shape)
return frozenset(decomposed_indices), decomp_flops
return frozenset(indices), 0

Check warning on line 122 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L116-L122

Added lines #L116 - L122 were not covered by tests


def optimal(

Check warning on line 125 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L125

Added line #L125 was not covered by tests
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the optimal function, consider adding a timeout mechanism or a maximum iteration limit to prevent potential infinite loops or excessive runtime for very large input sets. This would improve the robustness of the function. [important]

inputs: list[frozenset[int]],
fock_size_dict: dict[int, int],
info: bool = False,
) -> list[tuple[int, int]]:
"""Find the optimal contraction path for a mixed CV-Fock tensor network.

This function performs an exhaustive search over all possible contraction orders
for a tensor network containing both continuous-variable (CV) and Fock-space tensors.
It uses a depth-first recursive strategy to find the sequence of pairwise contractions
that minimizes the total computational cost (FLOPS).

CV indices are represented by integers not present in fock_size_dict, while Fock
indices must be keys in fock_size_dict. The algorithm caches intermediate results,
skips outer products (contractions between tensors with no shared indices), and
prunes the search when partial paths exceed the current best cost.

Args:
inputs: List of index sets representing tensor indices
fock_size_dict: Mapping from Fock index labels to dimensions
info: If True, prints cache size diagnostics

Returns:
tuple[tuple[int, int], ...]: The optimal contraction path as a sequence of pairs.
Each pair (i, j) indicates that tensors at positions i and j should be
contracted together. The resulting tensor is placed at position len(inputs).

Example:
>>> inputs = [frozenset({0, 1}), frozenset({1, 2}), frozenset({2, 3})]
>>> fock_size_dict = {1: 2, 2: 2} # indices 0 and 3 are CV indices
>>> optimal(inputs, fock_size_dict)
((0, 1), (2, 3))

Reference:
Based on the optimal path finder in opt_einsum:
https://github.com/dgasmith/opt_einsum/blob/master/opt_einsum/paths.py
"""
best_flops: int = float("inf")
best_path: tuple[tuple[int, int], ...] = ()
result_cache: dict[tuple[frozenset[int], frozenset[int]], tuple[frozenset[int], int]] = {}

Check warning on line 164 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L162-L164

Added lines #L162 - L164 were not covered by tests

def _optimal_iterate(path, remaining, inputs, flops):

Check warning on line 166 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L166

Added line #L166 was not covered by tests
nonlocal best_flops
nonlocal best_path

if len(remaining) == 1:
best_flops = flops
best_path = path
return

Check warning on line 173 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L170-L173

Added lines #L170 - L173 were not covered by tests

# check all remaining paths
for i, j in itertools.combinations(remaining, 2):
if i > j:
i, j = j, i

Check warning on line 178 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L176-L178

Added lines #L176 - L178 were not covered by tests

# skip outer products
if not inputs[i] & inputs[j]:
continue

Check warning on line 182 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L181-L182

Added lines #L181 - L182 were not covered by tests

key = (inputs[i], inputs[j])
try:
new_indices, flops_ij = result_cache[key]
except KeyError:
new_indices, flops_ij = result_cache[key] = new_indices_and_flops(

Check warning on line 188 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L184-L188

Added lines #L184 - L188 were not covered by tests
*key, fock_size_dict
)

# sieve based on current best flops
new_flops = flops + flops_ij
if new_flops >= best_flops:
continue

Check warning on line 195 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L193-L195

Added lines #L193 - L195 were not covered by tests

# add contraction and recurse into all remaining
_optimal_iterate(

Check warning on line 198 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L198

Added line #L198 was not covered by tests
path=path + ((i, j),),
inputs=inputs + (new_indices,),
remaining=remaining - {i, j} | {len(inputs)},
flops=new_flops,
)

_optimal_iterate(

Check warning on line 205 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L205

Added line #L205 was not covered by tests
path=(), inputs=tuple(map(frozenset, inputs)), remaining=set(range(len(inputs))), flops=0
)

if info:
print("len(fock_size_dict)", len(fock_size_dict), "len(result_cache)", len(result_cache))
return best_path

Check warning on line 211 in mrmustard/physics/mm_einsum.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/physics/mm_einsum.py#L209-L211

Added lines #L209 - L211 were not covered by tests
Loading