Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch Refactor #521

Open
wants to merge 53 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
b58afcd
init
apchytr Nov 6, 2024
dc9376c
prog
apchytr Nov 8, 2024
a4e7f75
docs
apchytr Nov 8, 2024
96403ec
some tests passing
apchytr Nov 11, 2024
a675779
progress up to and
apchytr Nov 12, 2024
0ae4137
prog
apchytr Nov 12, 2024
e4f9ecc
prog
apchytr Nov 12, 2024
5abeaa8
fix
apchytr Nov 12, 2024
a52e517
getitem initial
apchytr Nov 20, 2024
fad7fd4
IndexError messages
apchytr Nov 20, 2024
43d2217
simplify working
apchytr Nov 21, 2024
c390929
moving to using np arrays
apchytr Nov 22, 2024
f6c11d2
almost all polyexp tests passing
apchytr Nov 26, 2024
2a825bf
polyexp passing
apchytr Nov 26, 2024
8d12932
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Nov 26, 2024
00276a0
initial tests passing
apchytr Nov 27, 2024
c33baad
fixing reduce
apchytr Nov 27, 2024
5edf496
up to a single tf test
apchytr Nov 27, 2024
b1bbf29
fix test
apchytr Nov 27, 2024
a389643
update
apchytr Nov 27, 2024
fb13b18
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Dec 2, 2024
b6150b6
gitignore
apchytr Dec 2, 2024
a4418ca
contract
apchytr Dec 2, 2024
b496720
slight update
apchytr Dec 2, 2024
385532f
oops
apchytr Dec 2, 2024
d6c4f62
cleanup
apchytr Dec 2, 2024
5011f1a
ArrayAnsatz
apchytr Dec 3, 2024
bd8adc9
doc
apchytr Dec 3, 2024
e8efc8c
bargmann triple
apchytr Dec 3, 2024
8c0c013
fixed
apchytr Dec 3, 2024
b57a6e5
doc
apchytr Dec 3, 2024
985de0d
updates
apchytr Dec 4, 2024
cba0b37
batch tests
apchytr Dec 4, 2024
2673213
positional
apchytr Dec 4, 2024
9cd0245
cleanup and test
apchytr Dec 4, 2024
49a62c8
codefactor
apchytr Dec 4, 2024
b10e0d9
getitem tests
apchytr Dec 4, 2024
78be2ce
cov
apchytr Dec 4, 2024
1a80431
cov
apchytr Dec 5, 2024
654e44a
index error
apchytr Dec 5, 2024
0f6f8bd
batch shape validation
apchytr Dec 5, 2024
2bd37f1
rem
apchytr Dec 5, 2024
f0bce35
skip tf
apchytr Dec 5, 2024
9a5679c
codefactor
apchytr Dec 5, 2024
3821128
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Dec 10, 2024
5156640
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Dec 10, 2024
8491ead
rem conjugate
apchytr Dec 10, 2024
39120d4
progress
apchytr Dec 12, 2024
f30cdb5
fix
apchytr Dec 13, 2024
bdceb0b
rep cov
apchytr Dec 16, 2024
d1495c6
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Dec 16, 2024
5d7a51d
nuke apply in favor of get_fn
apchytr Dec 16, 2024
435ec90
pragma no cover
apchytr Dec 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mrmustard/lab_dev/circuit_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def quadrature(self, quad: Batch[Vector], phi: float = 0.0) -> ComplexTensor:
# Find where all the bras and kets are so they can be conjugated appropriately
conjugates = [i not in self.wires.ket.indices for i in range(len(self.wires.indices))]
quad_basis = math.sum(
[quadrature_basis(array, quad, conjugates, phi) for array in fock_arrays], axes=[0]
[quadrature_basis(array, quad, conjugates, phi) for array in fock_arrays], axis=[0]
)
return quad_basis

Expand Down
2 changes: 1 addition & 1 deletion mrmustard/lab_dev/circuit_components_utils/trace_out.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __custom_rrshift__(self, other: CircuitComponent | complex) -> CircuitCompon
ansatz = other.ansatz
wires = other.wires
elif not ket or not bra:
ansatz = other.ansatz.conj[idx_z] @ other.ansatz[idx_z]
ansatz = other.ansatz.conj.contract(other.ansatz, idx_z, idx_z)
wires, _ = (other.wires.adjoint @ other.wires)[0] @ self.wires
else:
ansatz = other.ansatz.trace(idx_z, idx_zconj)
Expand Down
12 changes: 5 additions & 7 deletions mrmustard/lab_dev/states/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ def from_fock(
modes: Sequence[int],
array: ComplexTensor,
name: str | None = None,
batched: bool = False,
) -> State:
r"""
Initializes a state of type ``cls`` from an array parametrizing the
Expand All @@ -218,17 +217,16 @@ def from_fock(

>>> modes = [0]
>>> array = Coherent(modes, x=0.1).to_fock().ansatz.array
>>> coh = Ket.from_fock(modes, array, batched=True)
>>> coh = Ket.from_fock(modes, array)

>>> assert coh.modes == modes
>>> assert coh.ansatz == ArrayAnsatz(array, batched=True)
>>> assert coh.ansatz == ArrayAnsatz(array)
>>> assert isinstance(coh, Ket)

Args:
modes: The modes of this state.
array: The Fock array.
name: The name of this state.
batched: Whether the given array is batched.

Returns:
A state.
Expand Down Expand Up @@ -385,7 +383,7 @@ def visualize_2d(
shape = [max(min_shape, d) for d in self.auto_shape()]
state = self.to_fock(tuple(shape))
state = state.dm()
dm = math.sum(state.ansatz.array, axes=[0])
dm = math.sum(state.ansatz.array, axis=[0])

x, prob_x = quadrature_distribution(dm)
p, prob_p = quadrature_distribution(dm, np.pi / 2)
Expand Down Expand Up @@ -501,7 +499,7 @@ def visualize_3d(
shape = [max(min_shape, d) for d in self.auto_shape()]
state = self.to_fock(tuple(shape))
state = state.dm()
dm = math.sum(state.ansatz.array, axes=[0])
dm = math.sum(state.ansatz.array, axis=[0])

xvec = np.linspace(*xbounds, resolution)
pvec = np.linspace(*pbounds, resolution)
Expand Down Expand Up @@ -575,7 +573,7 @@ def visualize_dm(
raise ValueError("DM visualization not available for multi-mode states.")
state = self.to_fock(cutoff)
state = state.dm()
dm = math.sum(state.ansatz.array, axes=[0])
dm = math.sum(state.ansatz.array, axis=[0])

fig = go.Figure(
data=go.Heatmap(z=abs(dm), colorscale="viridis", name="abs(ρ)", showscale=False)
Expand Down
6 changes: 3 additions & 3 deletions mrmustard/lab_dev/states/dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from mrmustard.math.lattice.strategies.vanilla import autoshape_numba
from mrmustard.physics.ansatz import ArrayAnsatz, PolyExpAnsatz
from mrmustard.physics.bargmann_utils import wigner_to_bargmann_rho
from mrmustard.physics.batches import Batch
from mrmustard.physics.gaussian_integrals import complex_gaussian_integral_2
from mrmustard.physics.representations import Representation
from mrmustard.physics.wires import Wires
Expand Down Expand Up @@ -121,11 +122,10 @@ def from_bargmann(
def from_fock(
cls,
modes: Sequence[int],
array: ComplexTensor,
array: ComplexTensor | Batch[ComplexTensor],
name: str | None = None,
batched: bool = False,
) -> State:
return DM.from_ansatz(modes, ArrayAnsatz(array, batched), name)
return DM.from_ansatz(modes, ArrayAnsatz(array), name)

@classmethod
def from_ansatz(
Expand Down
7 changes: 3 additions & 4 deletions mrmustard/lab_dev/states/ket.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from mrmustard.math.lattice.strategies.vanilla import autoshape_numba
from mrmustard.physics.ansatz import ArrayAnsatz, PolyExpAnsatz
from mrmustard.physics.bargmann_utils import wigner_to_bargmann_psi
from mrmustard.physics.batches import Batch
from mrmustard.physics.gaussian import purity
from mrmustard.physics.representations import Representation
from mrmustard.physics.wires import Wires
Expand All @@ -37,7 +38,6 @@
ComplexTensor,
RealVector,
Scalar,
Batch,
)

from .base import State, _validate_operator, OperatorType
Expand Down Expand Up @@ -100,11 +100,10 @@ def from_bargmann(
def from_fock(
cls,
modes: Sequence[int],
array: ComplexTensor,
array: ComplexTensor | Batch[ComplexTensor],
name: str | None = None,
batched: bool = False,
) -> State:
return Ket.from_ansatz(modes, ArrayAnsatz(array, batched), name)
return Ket.from_ansatz(modes, ArrayAnsatz(array), name)

@classmethod
def from_ansatz(
Expand Down
8 changes: 4 additions & 4 deletions mrmustard/lab_dev/transformations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def symplectic(self):
Returns the symplectic matrix that corresponds to this unitary
"""
batch_size = self.ansatz.batch_size
return [au2Symplectic(self.ansatz.A[batch, :, :]) for batch in range(batch_size)]
return [au2Symplectic(self.ansatz.A[batch]) for batch in range(batch_size)]

@classmethod
def from_bargmann(
Expand Down Expand Up @@ -380,7 +380,7 @@ def is_CP(self) -> bool:
raise ValueError(
"Physicality conditions are not implemented for batch dimension larger than 1."
)
A = self.ansatz.A
A = self.ansatz.A.data
m = A.shape[-1] // 2
gamma_A = A[0, :m, m:]

Expand All @@ -396,7 +396,7 @@ def is_TP(self) -> bool:
r"""
Whether this channel is trace preserving (TP).
"""
A = self.ansatz.A
A = self.ansatz.A.data
m = A.shape[-1] // 2
gamma_A = A[0, :m, m:]
lambda_A = A[0, m:, m:]
Expand Down Expand Up @@ -521,7 +521,7 @@ def random(cls, modes: Sequence[int], max_r: float = 1.0) -> Channel:
U = Unitary.random(range(3 * m), max_r)
u_psi = Vacuum(range(2 * m)) >> U
A = u_psi.ansatz
kraus = A.conj[range(2 * m)] @ A[range(2 * m)]
kraus = A.conj.contract(A, range(2 * m), range(2 * m))
return Channel.from_bargmann(modes, modes, kraus.triple)

def __rshift__(self, other: CircuitComponent) -> CircuitComponent:
Expand Down
11 changes: 7 additions & 4 deletions mrmustard/lab_dev/transformations/dgate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
from mrmustard import math

from .base import Unitary
from ...physics.representations import Representation
from ...physics.ansatz import PolyExpAnsatz, ArrayAnsatz
from ...physics.batches import Batch
from ...physics.representations import Representation
from ...physics import triples, fock_utils
from ..utils import make_parameter, reshape_params

Expand Down Expand Up @@ -103,7 +104,9 @@ def __init__(
),
).representation

def fock_array(self, shape: int | Sequence[int] = None, batched=False) -> ComplexTensor:
def fock_array(
self, shape: int | Sequence[int] = None, batched=False
) -> ComplexTensor | Batch[ComplexTensor]:
r"""
Returns the unitary representation of the Displacement gate using the Laguerre polynomials.
If the shape is not given, it defaults to the ``auto_shape`` of the component if it is
Expand Down Expand Up @@ -144,11 +147,11 @@ def fock_array(self, shape: int | Sequence[int] = None, batched=False) -> Comple
)
else:
array = fock_utils.displacement(x[0], y[0], shape=shape)
arrays = math.expand_dims(array, 0) if batched else array
arrays = Batch(math.expand_dims(array, 0)) if batched else array
return arrays

def to_fock(self, shape: int | Sequence[int] | None = None) -> Dgate:
fock = ArrayAnsatz(self.fock_array(shape, batched=True), batched=True)
fock = ArrayAnsatz(self.fock_array(shape, batched=True))
fock._original_abc_data = self.ansatz.triple
ret = self._getitem_builtin(self.modes)
ret._representation = Representation(fock, self.wires)
Expand Down
2 changes: 1 addition & 1 deletion mrmustard/lab_dev/transformations/phasenoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,4 @@ def __custom_rrshift__(self, other: CircuitComponent) -> CircuitComponent:
* self.phase_stdev.value**2
)
array *= phase_factors
return CircuitComponent(Representation(ArrayAnsatz(array, False), other.wires), self.name)
return CircuitComponent(Representation(ArrayAnsatz(array), other.wires), self.name)
Loading
Loading