diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index 944b6e5e6..05ddff3b4 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -336,7 +336,7 @@ def quadrature(self, quad: Batch[Vector], phi: float = 0.0) -> ComplexTensor: # Find where all the bras and kets are so they can be conjugated appropriately conjugates = [i not in self.wires.ket.indices for i in range(len(self.wires.indices))] quad_basis = math.sum( - [quadrature_basis(array, quad, conjugates, phi) for array in fock_arrays], axes=[0] + [quadrature_basis(array, quad, conjugates, phi) for array in fock_arrays], axis=[0] ) return quad_basis diff --git a/mrmustard/lab_dev/circuit_components_utils/trace_out.py b/mrmustard/lab_dev/circuit_components_utils/trace_out.py index 07cc83801..b468d898b 100644 --- a/mrmustard/lab_dev/circuit_components_utils/trace_out.py +++ b/mrmustard/lab_dev/circuit_components_utils/trace_out.py @@ -86,7 +86,7 @@ def __custom_rrshift__(self, other: CircuitComponent | complex) -> CircuitCompon ansatz = other.ansatz wires = other.wires elif not ket or not bra: - ansatz = other.ansatz.conj[idx_z] @ other.ansatz[idx_z] + ansatz = other.ansatz.conj.contract(other.ansatz, idx_z, idx_z) wires, _ = (other.wires.adjoint @ other.wires)[0] @ self.wires else: ansatz = other.ansatz.trace(idx_z, idx_zconj) diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index fa6558d03..c81f7ff18 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -204,7 +204,6 @@ def from_fock( modes: Sequence[int], array: ComplexTensor, name: str | None = None, - batched: bool = False, ) -> State: r""" Initializes a state of type ``cls`` from an array parametrizing the @@ -218,17 +217,16 @@ def from_fock( >>> modes = [0] >>> array = Coherent(modes, x=0.1).to_fock().ansatz.array - >>> coh = Ket.from_fock(modes, array, batched=True) + >>> coh = Ket.from_fock(modes, array) >>> assert coh.modes == modes - >>> assert coh.ansatz == ArrayAnsatz(array, batched=True) + >>> assert coh.ansatz == ArrayAnsatz(array) >>> assert isinstance(coh, Ket) Args: modes: The modes of this state. array: The Fock array. name: The name of this state. - batched: Whether the given array is batched. Returns: A state. @@ -385,7 +383,7 @@ def visualize_2d( shape = [max(min_shape, d) for d in self.auto_shape()] state = self.to_fock(tuple(shape)) state = state.dm() - dm = math.sum(state.ansatz.array, axes=[0]) + dm = math.sum(state.ansatz.array, axis=[0]) x, prob_x = quadrature_distribution(dm) p, prob_p = quadrature_distribution(dm, np.pi / 2) @@ -501,7 +499,7 @@ def visualize_3d( shape = [max(min_shape, d) for d in self.auto_shape()] state = self.to_fock(tuple(shape)) state = state.dm() - dm = math.sum(state.ansatz.array, axes=[0]) + dm = math.sum(state.ansatz.array, axis=[0]) xvec = np.linspace(*xbounds, resolution) pvec = np.linspace(*pbounds, resolution) @@ -575,7 +573,7 @@ def visualize_dm( raise ValueError("DM visualization not available for multi-mode states.") state = self.to_fock(cutoff) state = state.dm() - dm = math.sum(state.ansatz.array, axes=[0]) + dm = math.sum(state.ansatz.array, axis=[0]) fig = go.Figure( data=go.Heatmap(z=abs(dm), colorscale="viridis", name="abs(ρ)", showscale=False) diff --git a/mrmustard/lab_dev/states/dm.py b/mrmustard/lab_dev/states/dm.py index 5b97a0553..cd87133bf 100644 --- a/mrmustard/lab_dev/states/dm.py +++ b/mrmustard/lab_dev/states/dm.py @@ -28,6 +28,7 @@ from mrmustard.math.lattice.strategies.vanilla import autoshape_numba from mrmustard.physics.ansatz import ArrayAnsatz, PolyExpAnsatz from mrmustard.physics.bargmann_utils import wigner_to_bargmann_rho +from mrmustard.physics.batches import Batch from mrmustard.physics.gaussian_integrals import complex_gaussian_integral_2 from mrmustard.physics.representations import Representation from mrmustard.physics.wires import Wires @@ -121,11 +122,10 @@ def from_bargmann( def from_fock( cls, modes: Sequence[int], - array: ComplexTensor, + array: ComplexTensor | Batch[ComplexTensor], name: str | None = None, - batched: bool = False, ) -> State: - return DM.from_ansatz(modes, ArrayAnsatz(array, batched), name) + return DM.from_ansatz(modes, ArrayAnsatz(array), name) @classmethod def from_ansatz( diff --git a/mrmustard/lab_dev/states/ket.py b/mrmustard/lab_dev/states/ket.py index b47e706bd..67c23c320 100644 --- a/mrmustard/lab_dev/states/ket.py +++ b/mrmustard/lab_dev/states/ket.py @@ -28,6 +28,7 @@ from mrmustard.math.lattice.strategies.vanilla import autoshape_numba from mrmustard.physics.ansatz import ArrayAnsatz, PolyExpAnsatz from mrmustard.physics.bargmann_utils import wigner_to_bargmann_psi +from mrmustard.physics.batches import Batch from mrmustard.physics.gaussian import purity from mrmustard.physics.representations import Representation from mrmustard.physics.wires import Wires @@ -37,7 +38,6 @@ ComplexTensor, RealVector, Scalar, - Batch, ) from .base import State, _validate_operator, OperatorType @@ -100,11 +100,10 @@ def from_bargmann( def from_fock( cls, modes: Sequence[int], - array: ComplexTensor, + array: ComplexTensor | Batch[ComplexTensor], name: str | None = None, - batched: bool = False, ) -> State: - return Ket.from_ansatz(modes, ArrayAnsatz(array, batched), name) + return Ket.from_ansatz(modes, ArrayAnsatz(array), name) @classmethod def from_ansatz( diff --git a/mrmustard/lab_dev/transformations/base.py b/mrmustard/lab_dev/transformations/base.py index 423cec247..f258ae68a 100644 --- a/mrmustard/lab_dev/transformations/base.py +++ b/mrmustard/lab_dev/transformations/base.py @@ -201,7 +201,7 @@ def symplectic(self): Returns the symplectic matrix that corresponds to this unitary """ batch_size = self.ansatz.batch_size - return [au2Symplectic(self.ansatz.A[batch, :, :]) for batch in range(batch_size)] + return [au2Symplectic(self.ansatz.A[batch]) for batch in range(batch_size)] @classmethod def from_bargmann( @@ -380,7 +380,7 @@ def is_CP(self) -> bool: raise ValueError( "Physicality conditions are not implemented for batch dimension larger than 1." ) - A = self.ansatz.A + A = self.ansatz.A.data m = A.shape[-1] // 2 gamma_A = A[0, :m, m:] @@ -396,7 +396,7 @@ def is_TP(self) -> bool: r""" Whether this channel is trace preserving (TP). """ - A = self.ansatz.A + A = self.ansatz.A.data m = A.shape[-1] // 2 gamma_A = A[0, :m, m:] lambda_A = A[0, m:, m:] @@ -521,7 +521,7 @@ def random(cls, modes: Sequence[int], max_r: float = 1.0) -> Channel: U = Unitary.random(range(3 * m), max_r) u_psi = Vacuum(range(2 * m)) >> U A = u_psi.ansatz - kraus = A.conj[range(2 * m)] @ A[range(2 * m)] + kraus = A.conj.contract(A, range(2 * m), range(2 * m)) return Channel.from_bargmann(modes, modes, kraus.triple) def __rshift__(self, other: CircuitComponent) -> CircuitComponent: diff --git a/mrmustard/lab_dev/transformations/dgate.py b/mrmustard/lab_dev/transformations/dgate.py index aa5f76a0f..ed1ead9ab 100644 --- a/mrmustard/lab_dev/transformations/dgate.py +++ b/mrmustard/lab_dev/transformations/dgate.py @@ -24,8 +24,9 @@ from mrmustard import math from .base import Unitary -from ...physics.representations import Representation from ...physics.ansatz import PolyExpAnsatz, ArrayAnsatz +from ...physics.batches import Batch +from ...physics.representations import Representation from ...physics import triples, fock_utils from ..utils import make_parameter, reshape_params @@ -103,7 +104,9 @@ def __init__( ), ).representation - def fock_array(self, shape: int | Sequence[int] = None, batched=False) -> ComplexTensor: + def fock_array( + self, shape: int | Sequence[int] = None, batched=False + ) -> ComplexTensor | Batch[ComplexTensor]: r""" Returns the unitary representation of the Displacement gate using the Laguerre polynomials. If the shape is not given, it defaults to the ``auto_shape`` of the component if it is @@ -144,11 +147,11 @@ def fock_array(self, shape: int | Sequence[int] = None, batched=False) -> Comple ) else: array = fock_utils.displacement(x[0], y[0], shape=shape) - arrays = math.expand_dims(array, 0) if batched else array + arrays = Batch(math.expand_dims(array, 0)) if batched else array return arrays def to_fock(self, shape: int | Sequence[int] | None = None) -> Dgate: - fock = ArrayAnsatz(self.fock_array(shape, batched=True), batched=True) + fock = ArrayAnsatz(self.fock_array(shape, batched=True)) fock._original_abc_data = self.ansatz.triple ret = self._getitem_builtin(self.modes) ret._representation = Representation(fock, self.wires) diff --git a/mrmustard/lab_dev/transformations/phasenoise.py b/mrmustard/lab_dev/transformations/phasenoise.py index f990ae53c..67e18a3c8 100644 --- a/mrmustard/lab_dev/transformations/phasenoise.py +++ b/mrmustard/lab_dev/transformations/phasenoise.py @@ -83,4 +83,4 @@ def __custom_rrshift__(self, other: CircuitComponent) -> CircuitComponent: * self.phase_stdev.value**2 ) array *= phase_factors - return CircuitComponent(Representation(ArrayAnsatz(array, False), other.wires), self.name) + return CircuitComponent(Representation(ArrayAnsatz(array), other.wires), self.name) diff --git a/mrmustard/math/backend_manager.py b/mrmustard/math/backend_manager.py index ece784651..df909c307 100644 --- a/mrmustard/math/backend_manager.py +++ b/mrmustard/math/backend_manager.py @@ -98,17 +98,13 @@ def __init__(self) -> None: # binding types and decorators of numpy backend self._bind() - def _apply(self, fn: str, args: Sequence[Any] | None = ()) -> Any: - r""" - Applies a function ``fn`` from the backend in use to the given ``args``. - """ + def _get_fn(self, fn_str: str) -> Callable: try: - attr = getattr(self.backend, fn) + return getattr(self.backend, fn_str) except AttributeError: - msg = f"Function ``{fn}`` not implemented for backend ``{self.backend_name}``." + msg = f"Function ``{fn_str}`` not implemented for backend ``{self.backend_name}``." # pylint: disable=raise-missing-from raise NotImplementedError(msg) - return attr(*args) def _bind(self) -> None: r""" @@ -199,7 +195,11 @@ def abs(self, array: Tensor) -> Tensor: Returns: The absolute value of the given ``array``. """ - return self._apply("abs", (array,)) + fn = self._get_fn("abs") + try: + return array.__array_ufunc__(fn, "__call__", array) + except AttributeError: + return fn(array) def allclose(self, array1: Tensor, array2: Tensor, atol=1e-9) -> bool: r""" @@ -218,7 +218,8 @@ def allclose(self, array1: Tensor, array2: Tensor, atol=1e-9) -> bool: Raises: ValueError: If the shape of the two arrays do not match. """ - return self._apply("allclose", (array1, array2, atol)) + fn = self._get_fn("allclose") + return fn(array1, array2, atol) def any(self, array: Tensor) -> bool: r"""Returns ``True`` if any element of array is ``True``, ``False`` otherwise. @@ -229,7 +230,8 @@ def any(self, array: Tensor) -> bool: Returns: ``True`` if any element of array is ``True``, ``False`` otherwise. """ - return self._apply("any", (array,)) + fn = self._get_fn("any") + return fn(array) def arange(self, start: int, limit: int = None, delta: int = 1, dtype: Any = None) -> Tensor: r"""Returns an array of evenly spaced values within a given interval. @@ -238,13 +240,13 @@ def arange(self, start: int, limit: int = None, delta: int = 1, dtype: Any = Non start: The start of the interval. limit: The end of the interval. delta: The step size. - dtype: The dtype of the returned array. + dtype: The dtype of the returned array. Defaults to float64. Returns: The array of evenly spaced values. """ - # NOTE: is float64 by default - return self._apply("arange", (start, limit, delta, dtype)) + fn = self._get_fn("arange") + return fn(start, limit, delta, dtype) def asnumpy(self, tensor: Tensor) -> Tensor: r"""Converts an array to a numpy array. @@ -253,9 +255,10 @@ def asnumpy(self, tensor: Tensor) -> Tensor: tensor: The tensor to convert. Returns: - The corresponidng numpy array. + The corresponding numpy array. """ - return self._apply("asnumpy", (tensor,)) + fn = self._get_fn("asnumpy") + return fn(tensor) def assign(self, tensor: Tensor, value: Tensor) -> Tensor: r"""Assigns value to tensor. @@ -267,7 +270,8 @@ def assign(self, tensor: Tensor, value: Tensor) -> Tensor: Returns: The tensor with value assigned """ - return self._apply("assign", (tensor, value)) + fn = self._get_fn("assign") + return fn(tensor, value) def astensor(self, array: Tensor, dtype=None): r"""Converts a numpy array to a tensor. @@ -280,7 +284,11 @@ def astensor(self, array: Tensor, dtype=None): Returns: The tensor with dtype. """ - return self._apply("astensor", (array, dtype)) + fn = self._get_fn("astensor") + try: + return array.__array_ufunc__(fn, "__call__", array, dtype=dtype) + except AttributeError: + return fn(array, dtype) def atleast_1d(self, array: Tensor, dtype=None) -> Tensor: r"""Returns an array with at least one dimension. @@ -293,7 +301,8 @@ def atleast_1d(self, array: Tensor, dtype=None) -> Tensor: Returns: The array with at least one dimension. """ - return self._apply("atleast_1d", (array, dtype)) + fn = self._get_fn("atleast_1d") + return fn(array, dtype) def atleast_2d(self, array: Tensor, dtype=None) -> Tensor: r"""Returns an array with at least two dimensions. @@ -306,7 +315,8 @@ def atleast_2d(self, array: Tensor, dtype=None) -> Tensor: Returns: The array with at least two dimensions. """ - return self._apply("atleast_2d", (array, dtype)) + fn = self._get_fn("atleast_2d") + return fn(array, dtype) def atleast_3d(self, array: Tensor, dtype=None) -> Tensor: r"""Returns an array with at least three dimensions by eventually inserting @@ -321,7 +331,8 @@ def atleast_3d(self, array: Tensor, dtype=None) -> Tensor: Returns: The array with at least three dimensions. """ - return self._apply("atleast_3d", (array, dtype)) + fn = self._get_fn("atleast_3d") + return fn(array, dtype) def block_diag(self, mat1: Matrix, mat2: Matrix) -> Matrix: r"""Returns a block diagonal matrix from the given matrices. @@ -333,7 +344,8 @@ def block_diag(self, mat1: Matrix, mat2: Matrix) -> Matrix: Returns: A block diagonal matrix from the given matrices. """ - return self._apply("block_diag", (mat1, mat2)) + fn = self._get_fn("block_diag") + return fn(mat1, mat2) def boolean_mask(self, tensor: Tensor, mask: Tensor) -> Tensor: """ @@ -346,7 +358,8 @@ def boolean_mask(self, tensor: Tensor, mask: Tensor) -> Tensor: Returns: A tensor based on the truth value of the boolean mask. """ - return self._apply("boolean_mask", (tensor, mask)) + fn = self._get_fn("boolean_mask") + return fn(tensor, mask) def block(self, blocks: list[list[Tensor]], axes=(-2, -1)) -> Tensor: r"""Returns a matrix made from the given blocks. @@ -358,7 +371,8 @@ def block(self, blocks: list[list[Tensor]], axes=(-2, -1)) -> Tensor: Returns: The matrix made of blocks. """ - return self._apply("block", (blocks, axes)) + fn = self._get_fn("block") + return fn(blocks, axes) def cast(self, array: Tensor, dtype=None) -> Tensor: r"""Casts ``array`` to ``dtype``. @@ -371,7 +385,8 @@ def cast(self, array: Tensor, dtype=None) -> Tensor: Returns: The array cast to dtype. """ - return self._apply("cast", (array, dtype)) + fn = self._get_fn("cast") + return fn(array, dtype) def clip(self, array: Tensor, a_min: float, a_max: float) -> Tensor: r"""Clips array to the interval ``[a_min, a_max]``. @@ -384,7 +399,8 @@ def clip(self, array: Tensor, a_min: float, a_max: float) -> Tensor: Returns: The clipped array. """ - return self._apply("clip", (array, a_min, a_max)) + fn = self._get_fn("clip") + return fn(array, a_min, a_max) def concat(self, values: Sequence[Tensor], axis: int) -> Tensor: r"""Concatenates values along the given axis. @@ -396,7 +412,8 @@ def concat(self, values: Sequence[Tensor], axis: int) -> Tensor: Returns: The concatenated values. """ - return self._apply("concat", (values, axis)) + fn = self._get_fn("concat") + return fn(values, axis) def conj(self, array: Tensor) -> Tensor: r"""The complex conjugate of array. @@ -407,9 +424,15 @@ def conj(self, array: Tensor) -> Tensor: Returns: The complex conjugate of the given ``array``. """ - return self._apply("conj", (array,)) + fn = self._get_fn("conj") + try: + return array.__array_ufunc__(fn, "__call__", array) + except AttributeError: + return fn(array) - def constraint_func(self, bounds: tuple[float | None, float | None]) -> Callable | None: + def constraint_func( + self, bounds: tuple[float | None, float | None] + ) -> Callable | None: # pragma: no cover r"""Returns a constraint function for the given bounds. A constraint function will clip the value to the interval given by the bounds. @@ -425,7 +448,8 @@ def constraint_func(self, bounds: tuple[float | None, float | None]) -> Callable Returns: The constraint function. """ - return self._apply("constraint_func", (bounds)) + fn = self._get_fn("constraint_func") + return fn(bounds) def convolution( self, @@ -445,7 +469,8 @@ def convolution( Returns: The convolved array. """ - return self._apply("convolution", (array, filters, padding, data_format)) + fn = self._get_fn("convolution") + return fn(array, filters, padding, data_format) def cos(self, array: Tensor) -> Tensor: r"""The cosine of an array. @@ -456,7 +481,8 @@ def cos(self, array: Tensor) -> Tensor: Returns: The cosine of ``array``. """ - return self._apply("cos", (array,)) + fn = self._get_fn("cos") + return fn(array) def cosh(self, array: Tensor) -> Tensor: r"""The hyperbolic cosine of array. @@ -467,7 +493,8 @@ def cosh(self, array: Tensor) -> Tensor: Returns: The hyperbolic cosine of ``array``. """ - return self._apply("cosh", (array,)) + fn = self._get_fn("cosh") + return fn(array) def det(self, matrix: Tensor) -> Tensor: r"""The determinant of matrix. @@ -478,7 +505,8 @@ def det(self, matrix: Tensor) -> Tensor: Returns: The determinant of ``matrix``. """ - return self._apply("det", (matrix,)) + fn = self._get_fn("det") + return fn(matrix) def diag(self, array: Tensor, k: int = 0) -> Tensor: r"""The array made by inserting the given array along the :math:`k`-th diagonal. @@ -490,7 +518,8 @@ def diag(self, array: Tensor, k: int = 0) -> Tensor: Returns: The array with ``array`` inserted into the ``k``-th diagonal. """ - return self._apply("diag", (array, k)) + fn = self._get_fn("diag") + return fn(array, k) def diag_part(self, array: Tensor, k: int = 0) -> Tensor: r"""The array of the main diagonal of array. @@ -502,7 +531,8 @@ def diag_part(self, array: Tensor, k: int = 0) -> Tensor: Returns: The array of the main diagonal of ``array``. """ - return self._apply("diag_part", (array, k)) + fn = self._get_fn("diag_part") + return fn(array, k) def eigvals(self, tensor: Tensor) -> Tensor: r"""The eigenvalues of a tensor. @@ -513,7 +543,8 @@ def eigvals(self, tensor: Tensor) -> Tensor: Returns: The eigenvalues of ``tensor``. """ - return self._apply("eigvals", (tensor,)) + fn = self._get_fn("eigvals") + return fn(tensor) def eigh(self, tensor: Tensor) -> Tensor: """ @@ -525,7 +556,8 @@ def eigh(self, tensor: Tensor) -> Tensor: Returns: The eigenvalues and eigenvectors of ``tensor``. """ - return self._apply("eigh", (tensor,)) + fn = self._get_fn("eigh") + return fn(tensor) def einsum(self, string: str, *tensors) -> Tensor: r"""The result of the Einstein summation convention on the tensors. @@ -537,7 +569,8 @@ def einsum(self, string: str, *tensors) -> Tensor: Returns: The result of the Einstein summation convention. """ - return self._apply("einsum", (string, *tensors)) + fn = self._get_fn("einsum") + return fn(string, *tensors) def exp(self, array: Tensor) -> Tensor: r"""The exponential of array element-wise. @@ -548,7 +581,8 @@ def exp(self, array: Tensor) -> Tensor: Returns: The exponential of array. """ - return self._apply("exp", (array,)) + fn = self._get_fn("exp") + return fn(array) def expand_dims(self, array: Tensor, axis: int) -> Tensor: r"""The array with an additional dimension inserted at the given axis. @@ -560,7 +594,8 @@ def expand_dims(self, array: Tensor, axis: int) -> Tensor: Returns: The array with an additional dimension inserted at the given axis. """ - return self._apply("expand_dims", (array, axis)) + fn = self._get_fn("expand_dims") + return fn(array, axis) def expm(self, matrix: Tensor) -> Tensor: r"""The matrix exponential of matrix. @@ -571,7 +606,8 @@ def expm(self, matrix: Tensor) -> Tensor: Returns: The exponential of ``matrix``. """ - return self._apply("expm", (matrix,)) + fn = self._get_fn("expm") + return fn(matrix) def eye(self, size: int, dtype=None) -> Tensor: r"""The identity matrix of size. @@ -584,7 +620,8 @@ def eye(self, size: int, dtype=None) -> Tensor: Returns: The identity matrix. """ - return self._apply("eye", (size, dtype)) + fn = self._get_fn("eye") + return fn(size, dtype) def eye_like(self, array: Tensor) -> Tensor: r"""The identity matrix of the same shape and dtype as array. @@ -595,7 +632,8 @@ def eye_like(self, array: Tensor) -> Tensor: Returns: The identity matrix. """ - return self._apply("eye_like", (array,)) + fn = self._get_fn("eye_like") + return fn(array) def from_backend(self, value: Any) -> bool: r"""Whether the given tensor is a tensor of the concrete backend. @@ -606,7 +644,8 @@ def from_backend(self, value: Any) -> bool: Returns: Whether given ``value`` is a tensor of the concrete backend. """ - return self._apply("from_backend", (value,)) + fn = self._get_fn("from_backend") + return fn(value) def gather(self, array: Tensor, indices: Batch[int], axis: int | None = None) -> Tensor: r"""The values of the array at the given indices. @@ -619,14 +658,8 @@ def gather(self, array: Tensor, indices: Batch[int], axis: int | None = None) -> Returns: The values of the array at the given indices. """ - return self._apply( - "gather", - ( - array, - indices, - axis, - ), - ) + fn = self._get_fn("gather") + return fn(array, indices, axis) def hermite_renormalized_batch( self, A: Tensor, B: Tensor, C: Tensor, shape: tuple[int] @@ -646,7 +679,8 @@ def hermite_renormalized_batch( Returns: The batched Hermite polynomial of given shape. """ - return self._apply("hermite_renormalized_batch", (A, B, C, shape)) + fn = self._get_fn("hermite_renormalized_batch") + return fn(A, B, C, shape) def hermite_renormalized_diagonal( self, A: Tensor, B: Tensor, C: Tensor, cutoffs: tuple[int] @@ -654,7 +688,8 @@ def hermite_renormalized_diagonal( r"""Firsts, reorder A and B parameters of Bargmann representation to match conventions in mrmustard.math.compactFock~ Then, calculates the required renormalized multidimensional Hermite polynomial. """ - return self._apply("hermite_renormalized_diagonal", (A, B, C, cutoffs)) + fn = self._get_fn("hermite_renormalized_diagonal") + return fn(A, B, C, cutoffs) def hermite_renormalized_diagonal_batch( self, A: Tensor, B: Tensor, C: Tensor, cutoffs: tuple[int] @@ -662,7 +697,8 @@ def hermite_renormalized_diagonal_batch( r"""First, reorder A and B parameters of Bargmann representation to match conventions in mrmustard.math.compactFock~ Then, calculates the required renormalized multidimensional Hermite polynomial. Same as hermite_renormalized_diagonal but works for a batch of different B's.""" - return self._apply("hermite_renormalized_diagonal_batch", (A, B, C, cutoffs)) + fn = self._get_fn("hermite_renormalized_diagonal_batch") + return fn(A, B, C, cutoffs) def hermite_renormalized_1leftoverMode( self, A: Tensor, B: Tensor, C: Tensor, cutoffs: tuple[int] @@ -670,7 +706,8 @@ def hermite_renormalized_1leftoverMode( r"""First, reorder A and B parameters of Bargmann representation to match conventions in mrmustard.math.compactFock~ Then, calculate the required renormalized multidimensional Hermite polynomial. """ - return self._apply("hermite_renormalized_1leftoverMode", (A, B, C, cutoffs)) + fn = self._get_fn("hermite_renormalized_1leftoverMode") + return fn(A, B, C, cutoffs) def imag(self, array: Tensor) -> Tensor: r"""The imaginary part of array. @@ -681,7 +718,8 @@ def imag(self, array: Tensor) -> Tensor: Returns: The imaginary part of array """ - return self._apply("imag", (array,)) + fn = self._get_fn("imag") + return fn(array) def inv(self, tensor: Tensor) -> Tensor: r"""The inverse of tensor. @@ -692,7 +730,8 @@ def inv(self, tensor: Tensor) -> Tensor: Returns: The inverse of tensor """ - return self._apply("inv", (tensor,)) + fn = self._get_fn("inv") + return fn(tensor) def is_trainable(self, tensor: Tensor) -> bool: r"""Whether the given tensor is trainable. @@ -703,7 +742,8 @@ def is_trainable(self, tensor: Tensor) -> bool: Returns: Whether the given tensor can be trained. """ - return self._apply("is_trainable", (tensor,)) + fn = self._get_fn("is_trainable") + return fn(tensor) def lgamma(self, x: Tensor) -> Tensor: r"""The natural logarithm of the gamma function of ``x``. @@ -714,7 +754,8 @@ def lgamma(self, x: Tensor) -> Tensor: Returns: The natural logarithm of the gamma function of ``x`` """ - return self._apply("lgamma", (x,)) + fn = self._get_fn("lgamma") + return fn(x) def log(self, x: Tensor) -> Tensor: r"""The natural logarithm of ``x``. @@ -725,7 +766,8 @@ def log(self, x: Tensor) -> Tensor: Returns: The natural logarithm of ``x`` """ - return self._apply("log", (x,)) + fn = self._get_fn("log") + return fn(x) def make_complex(self, real: Tensor, imag: Tensor) -> Tensor: """Given two real tensors representing the real and imaginary part of a complex number, @@ -738,7 +780,8 @@ def make_complex(self, real: Tensor, imag: Tensor) -> Tensor: Returns: The complex array ``real + 1j * imag``. """ - return self._apply("make_complex", (real, imag)) + fn = self._get_fn("make_complex") + return fn(real, imag) def matmul(self, *matrices: Matrix) -> Tensor: r"""The matrix product of the given matrices. @@ -749,7 +792,8 @@ def matmul(self, *matrices: Matrix) -> Tensor: Returns: The matrix product """ - return self._apply("matmul", matrices) + fn = self._get_fn("matmul") + return fn(*matrices) def matvec(self, a: Matrix, b: Vector) -> Tensor: r"""The matrix vector product of ``a`` (matrix) and ``b`` (vector). @@ -761,7 +805,8 @@ def matvec(self, a: Matrix, b: Vector) -> Tensor: Returns: The matrix vector product of ``a`` and ``b`` """ - return self._apply("matvec", (a, b)) + fn = self._get_fn("matvec") + return fn(a, b) def maximum(self, a: Tensor, b: Tensor) -> Tensor: r"""The element-wise maximum of ``a`` and ``b``. @@ -773,13 +818,8 @@ def maximum(self, a: Tensor, b: Tensor) -> Tensor: Returns: The element-wise maximum of ``a`` and ``b`` """ - return self._apply( - "maximum", - ( - a, - b, - ), - ) + fn = self._get_fn("maximum") + return fn(a, b) def minimum(self, a: Tensor, b: Tensor) -> Tensor: r"""The element-wise minimum of ``a`` and ``b``. @@ -791,13 +831,8 @@ def minimum(self, a: Tensor, b: Tensor) -> Tensor: Returns: The element-wise minimum of ``a`` and ``b`` """ - return self._apply( - "minimum", - ( - a, - b, - ), - ) + fn = self._get_fn("minimum") + return fn(a, b) def moveaxis(self, array: Tensor, old: Tensor, new: Tensor) -> Tensor: r""" @@ -812,14 +847,8 @@ def moveaxis(self, array: Tensor, old: Tensor, new: Tensor) -> Tensor: Returns: The updated array """ - return self._apply( - "moveaxis", - ( - array, - old, - new, - ), - ) + fn = self._get_fn("moveaxis") + return fn(array, old, new) def new_variable( self, @@ -838,7 +867,8 @@ def new_variable( Returns: The new variable. """ - return self._apply("new_variable", (value, bounds, name, dtype)) + fn = self._get_fn("new_variable") + return fn(value, bounds, name, dtype) def new_constant(self, value: Tensor, name: str, dtype=None) -> Tensor: r"""Returns a new constant with the given value. @@ -851,7 +881,8 @@ def new_constant(self, value: Tensor, name: str, dtype=None) -> Tensor: Returns: The new constant """ - return self._apply("new_constant", (value, name, dtype)) + fn = self._get_fn("new_constant") + return fn(value, name, dtype) def norm(self, array: Tensor) -> Tensor: r"""The norm of array. @@ -862,7 +893,8 @@ def norm(self, array: Tensor) -> Tensor: Returns: The norm of array """ - return self._apply("norm", (array,)) + fn = self._get_fn("norm") + return fn(array) def ones(self, shape: Sequence[int], dtype=None) -> Tensor: r"""Returns an array of ones with the given ``shape`` and ``dtype``. @@ -875,8 +907,8 @@ def ones(self, shape: Sequence[int], dtype=None) -> Tensor: Returns: The array of ones """ - # NOTE : should be float64 by default - return self._apply("ones", (shape, dtype)) + fn = self._get_fn("ones") + return fn(shape, dtype) def ones_like(self, array: Tensor) -> Tensor: r"""Returns an array of ones with the same shape and ``dtype`` as ``array``. @@ -887,7 +919,8 @@ def ones_like(self, array: Tensor) -> Tensor: Returns: The array of ones """ - return self._apply("ones_like", (array,)) + fn = self._get_fn("ones_like") + return fn(array) def outer(self, array1: Tensor, array2: Tensor) -> Tensor: r"""The outer product of ``array1`` and ``array2``. @@ -899,7 +932,8 @@ def outer(self, array1: Tensor, array2: Tensor) -> Tensor: Returns: The outer product of array1 and array2 """ - return self._apply("outer", (array1, array2)) + fn = self._get_fn("outer") + return fn(array1, array2) def pad( self, @@ -919,9 +953,10 @@ def pad( Returns: The padded array """ - return self._apply("pad", (array, paddings, mode, constant_values)) + fn = self._get_fn("pad") + return fn(array, paddings, mode, constant_values) - def pinv(self, matrix: Tensor) -> Tensor: + def pinv(self, matrix: Tensor) -> Tensor: # pragma: no cover r"""The pseudo-inverse of matrix. Args: @@ -930,7 +965,8 @@ def pinv(self, matrix: Tensor) -> Tensor: Returns: The pseudo-inverse of matrix """ - return self._apply("pinv", (matrix,)) + fn = self._get_fn("pinv") + return fn(matrix) def pow(self, x: Tensor, y: Tensor) -> Tensor: r"""Returns :math:`x^y`. Broadcasts ``x`` and ``y`` if necessary. @@ -941,7 +977,8 @@ def pow(self, x: Tensor, y: Tensor) -> Tensor: Returns: The :math:`x^y` """ - return self._apply("pow", (x, y)) + fn = self._get_fn("pow") + return fn(x, y) def kron(self, tensor1: Tensor, tensor2: Tensor) -> Tensor: r""" @@ -954,7 +991,8 @@ def kron(self, tensor1: Tensor, tensor2: Tensor) -> Tensor: Returns: The Kroenecker product. """ - return self._apply("kron", (tensor1, tensor2)) + fn = self._get_fn("kron") + return fn(tensor1, tensor2) def prod(self, array: Tensor, axis=None) -> Tensor: r""" @@ -968,7 +1006,8 @@ def prod(self, array: Tensor, axis=None) -> Tensor: Returns: The product of the elements in ``array``. """ - return self._apply("prod", (array, axis)) + fn = self._get_fn("prod") + return fn(array, axis) def real(self, array: Tensor) -> Tensor: r"""The real part of ``array``. @@ -979,7 +1018,8 @@ def real(self, array: Tensor) -> Tensor: Returns: The real part of ``array`` """ - return self._apply("real", (array,)) + fn = self._get_fn("real") + return fn(array) def repeat(self, array: Tensor, repeats: int, axis: int = None) -> Tensor: """ @@ -993,7 +1033,8 @@ def repeat(self, array: Tensor, repeats: int, axis: int = None) -> Tensor: Returns: The tensor with repeated elements. """ - return self._apply("repeat", (array, repeats, axis)) + fn = self._get_fn("repeat") + return fn(array, repeats, axis) def reshape(self, array: Tensor, shape: Sequence[int]) -> Tensor: r"""The reshaped array. @@ -1005,7 +1046,8 @@ def reshape(self, array: Tensor, shape: Sequence[int]) -> Tensor: Returns: The reshaped array """ - return self._apply("reshape", (array, shape)) + fn = self._get_fn("reshape") + return fn(array, shape) def round(self, array: Tensor, decimals: int) -> Tensor: r"""The array rounded to the nearest integer. @@ -1017,7 +1059,8 @@ def round(self, array: Tensor, decimals: int) -> Tensor: Returns: The array rounded to the nearest integer """ - return self._apply("round", (array, decimals)) + fn = self._get_fn("round") + return fn(array, decimals) def set_diag(self, array: Tensor, diag: Tensor, k: int) -> Tensor: r"""The array with the diagonal set to ``diag``. @@ -1030,7 +1073,8 @@ def set_diag(self, array: Tensor, diag: Tensor, k: int) -> Tensor: Returns: The array with the diagonal set to ``diag`` """ - return self._apply("set_diag", (array, diag, k)) + fn = self._get_fn("set_diag") + return fn(array, diag, k) def sin(self, array: Tensor) -> Tensor: r"""The sine of ``array``. @@ -1041,7 +1085,8 @@ def sin(self, array: Tensor) -> Tensor: Returns: The sine of ``array`` """ - return self._apply("sin", (array,)) + fn = self._get_fn("sin") + return fn(array) def sinh(self, array: Tensor) -> Tensor: r"""The hyperbolic sine of ``array``. @@ -1052,7 +1097,8 @@ def sinh(self, array: Tensor) -> Tensor: Returns: The hyperbolic sine of ``array`` """ - return self._apply("sinh", (array,)) + fn = self._get_fn("sinh") + return fn(array) def solve(self, matrix: Tensor, rhs: Tensor) -> Tensor: r"""The solution of the linear system :math:`Ax = b`. @@ -1064,7 +1110,8 @@ def solve(self, matrix: Tensor, rhs: Tensor) -> Tensor: Returns: The solution :math:`x` """ - return self._apply("solve", (matrix, rhs)) + fn = self._get_fn("solve") + return fn(matrix, rhs) def sort(self, array: Tensor, axis: int = -1) -> Tensor: r"""Sort the array along an axis. @@ -1076,7 +1123,8 @@ def sort(self, array: Tensor, axis: int = -1) -> Tensor: Returns: A sorted version of the array in acending order. """ - return self._apply("sort", (array, axis)) + fn = self._get_fn("sort") + return fn(array, axis) def sqrt(self, x: Tensor, dtype=None) -> Tensor: r"""The square root of ``x``. @@ -1088,7 +1136,8 @@ def sqrt(self, x: Tensor, dtype=None) -> Tensor: Returns: The square root of ``x`` """ - return self._apply("sqrt", (x, dtype)) + fn = self._get_fn("sqrt") + return fn(x, dtype) def sqrtm(self, tensor: Tensor, dtype=None) -> Tensor: r"""The matrix square root. @@ -1100,23 +1149,28 @@ def sqrtm(self, tensor: Tensor, dtype=None) -> Tensor: Returns: The square root of ``x``""" - return self._apply("sqrtm", (tensor, dtype)) + fn = self._get_fn("sqrtm") + return fn(tensor, dtype) - def sum(self, array: Tensor, axes: Sequence[int] = None): + def sum(self, array: Tensor, axis: Sequence[int] = None): r"""The sum of array. Args: array: The array to take the sum of - axes (tuple): axes to sum over + axis (tuple): The axis to sum over Returns: The sum of array """ - if axes is not None: - neg = [a for a in axes if a < 0] - pos = [a for a in axes if a >= 0] - axes = sorted(neg) + sorted(pos)[::-1] - return self._apply("sum", (array, axes)) + fn = self._get_fn("sum") + if axis is not None: + neg = [a for a in axis if a < 0] + pos = [a for a in axis if a >= 0] + axis = tuple(sorted(neg) + sorted(pos)[::-1]) + try: + return array.__array_ufunc__(fn, "reduce", array, axis=axis) + except AttributeError: + return fn(array, axis) def tensordot(self, a: Tensor, b: Tensor, axes: Sequence[int]) -> Tensor: r"""The tensordot product of ``a`` and ``b``. @@ -1129,7 +1183,8 @@ def tensordot(self, a: Tensor, b: Tensor, axes: Sequence[int]) -> Tensor: Returns: The tensordot product of ``a`` and ``b`` """ - return self._apply("tensordot", (a, b, axes)) + fn = self._get_fn("tensordot") + return fn(a, b, axes) def tile(self, array: Tensor, repeats: Sequence[int]) -> Tensor: r"""The tiled array. @@ -1141,7 +1196,8 @@ def tile(self, array: Tensor, repeats: Sequence[int]) -> Tensor: Returns: The tiled array """ - return self._apply("tile", (array, repeats)) + fn = self._get_fn("tile") + return fn(array, repeats) def trace(self, array: Tensor, dtype=None) -> Tensor: r"""The trace of array. @@ -1153,7 +1209,8 @@ def trace(self, array: Tensor, dtype=None) -> Tensor: Returns: The trace of array """ - return self._apply("trace", (array, dtype)) + fn = self._get_fn("trace") + return fn(array, dtype) def transpose(self, a: Tensor, perm: Sequence[int] = None): r"""The transposed arrays. @@ -1165,7 +1222,8 @@ def transpose(self, a: Tensor, perm: Sequence[int] = None): Returns: The transposed array """ - return self._apply("transpose", (a, perm)) + fn = self._get_fn("transpose") + return fn(a, perm) def update_tensor(self, tensor: Tensor, indices: Tensor, values: Tensor) -> Tensor: r"""Updates a tensor in place with the given values. @@ -1178,7 +1236,8 @@ def update_tensor(self, tensor: Tensor, indices: Tensor, values: Tensor) -> Tens Returns: The updated tensor """ - return self._apply("update_tensor", (tensor, indices, values)) + fn = self._get_fn("update_tensor") + return fn(tensor, indices, values) def update_add_tensor(self, tensor: Tensor, indices: Tensor, values: Tensor) -> Tensor: r"""Updates a tensor in place by adding the given values. @@ -1191,7 +1250,8 @@ def update_add_tensor(self, tensor: Tensor, indices: Tensor, values: Tensor) -> Returns: The updated tensor """ - return self._apply("update_add_tensor", (tensor, indices, values)) + fn = self._get_fn("update_add_tensor") + return fn(tensor, indices, values) def value_and_gradients( self, cost_fn: Callable, parameters: dict[str, list[Trainable]] @@ -1205,13 +1265,15 @@ def value_and_gradients( Returns: tuple: loss and gradients (dict) of the given cost function """ - return self._apply("value_and_gradients", (cost_fn, parameters)) + fn = self._get_fn("value_and_gradients") + return fn(cost_fn, parameters) def xlogy(self, x: Tensor, y: Tensor) -> Tensor: """ Returns ``0`` if ``x == 0`` elementwise and ``x * log(y)`` otherwise. """ - return self._apply("xlogy", (x, y)) + fn = self._get_fn("xlogy") + return fn(x, y) def zeros(self, shape: Sequence[int], dtype=None) -> Tensor: r"""Returns an array of zeros with the given shape and ``dtype``. @@ -1224,7 +1286,8 @@ def zeros(self, shape: Sequence[int], dtype=None) -> Tensor: Returns: The array of zeros. """ - return self._apply("zeros", (shape, dtype)) + fn = self._get_fn("zeros") + return fn(shape, dtype) def zeros_like(self, array: Tensor) -> Tensor: r"""Returns an array of zeros with the same shape and ``dtype`` as ``array``. @@ -1235,9 +1298,10 @@ def zeros_like(self, array: Tensor) -> Tensor: Returns: The array of zeros. """ - return self._apply("zeros_like", (array,)) + fn = self._get_fn("zeros_like") + return fn(array) - def map_fn(self, fn: Callable, elements: Tensor) -> Tensor: + def map_fn(self, func: Callable, elements: Tensor) -> Tensor: """Transforms elems by applying fn to each element unstacked on axis 0. Args: @@ -1250,9 +1314,10 @@ def map_fn(self, fn: Callable, elements: Tensor) -> Tensor: Returns: Tensor: applied ``func`` on ``elements`` """ - return self._apply("map_fn", (fn, elements)) + fn = self._get_fn("map_fn") + return fn(func, elements) - def squeeze(self, tensor: Tensor, axis: list[int] | None) -> Tensor: + def squeeze(self, tensor: Tensor, axis: list[int] | None) -> Tensor: # pragma: no cover """Removes dimensions of size 1 from the shape of a tensor. Args: @@ -1263,7 +1328,8 @@ def squeeze(self, tensor: Tensor, axis: list[int] | None) -> Tensor: Returns: Tensor: tensor with one or more dimensions of size 1 removed """ - return self._apply("squeeze", (tensor, axis)) + fn = self._get_fn("squeeze") + return fn(tensor, axis) def cholesky(self, input: Tensor) -> Tensor: """Computes the Cholesky decomposition of square matrices. @@ -1274,7 +1340,8 @@ def cholesky(self, input: Tensor) -> Tensor: Returns: Tensor: tensor with the same type as input """ - return self._apply("cholesky", (input,)) + fn = self._get_fn("cholesky") + return fn(input) def Categorical(self, probs: Tensor, name: str): """Categorical distribution over integers. @@ -1286,7 +1353,8 @@ def Categorical(self, probs: Tensor, name: str): Returns: tfp.distributions.Categorical: instance of ``tfp.distributions.Categorical`` class """ - return self._apply("Categorical", (probs, name)) + fn = self._get_fn("Categorical") + return fn(probs, name) def MultivariateNormalTriL(self, loc: Tensor, scale_tril: Tensor): """Multivariate normal distribution on `R^k` and parameterized by a (batch of) length-k loc @@ -1300,7 +1368,8 @@ def MultivariateNormalTriL(self, loc: Tensor, scale_tril: Tensor): Returns: tfp.distributions.MultivariateNormalTriL: instance of ``tfp.distributions.MultivariateNormalTriL`` """ - return self._apply("MultivariateNormalTriL", (loc, scale_tril)) + fn = self._get_fn("MultivariateNormalTriL") + return fn(loc, scale_tril) def custom_gradient(self, func): r""" @@ -1321,7 +1390,8 @@ def wrapper(*args, **kwargs): def DefaultEuclideanOptimizer(self): r"""Default optimizer for the Euclidean parameters.""" - return self._apply("DefaultEuclideanOptimizer") + fn = self._get_fn("DefaultEuclideanOptimizer") + return fn() # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Methods that build on the basic ops and don't need to be overridden in the backend implementation diff --git a/mrmustard/math/backend_numpy.py b/mrmustard/math/backend_numpy.py index 452f7be0a..2e875ca21 100644 --- a/mrmustard/math/backend_numpy.py +++ b/mrmustard/math/backend_numpy.py @@ -14,7 +14,7 @@ """This module contains the numpy backend.""" -# pylint: disable = missing-function-docstring, missing-class-docstring, fixme, too-many-positional-arguments +# pylint: disable = missing-function-docstring, missing-class-docstring, fixme from __future__ import annotations @@ -390,14 +390,8 @@ def sort(self, array: np.ndarray, axis: int = -1) -> np.ndarray: def sqrt(self, x: np.ndarray, dtype=None) -> np.ndarray: return np.sqrt(self.cast(x, dtype)) - def sum(self, array: np.ndarray, axes: Sequence[int] = None): - if axes is None: - return np.sum(array) - - ret = array - for axis in axes: - ret = np.sum(ret, axis=axis) - return ret + def sum(self, array: np.ndarray, axis: Sequence[int] = None): + return np.sum(array, axis=axis) @Autocast() def tensordot(self, a: np.ndarray, b: np.ndarray, axes: list[int]) -> np.ndarray: diff --git a/mrmustard/math/backend_tensorflow.py b/mrmustard/math/backend_tensorflow.py index 9001dd235..474bd9760 100644 --- a/mrmustard/math/backend_tensorflow.py +++ b/mrmustard/math/backend_tensorflow.py @@ -336,8 +336,8 @@ def sort(self, array: tf.Tensor, axis: int = -1) -> tf.Tensor: def sqrt(self, x: tf.Tensor, dtype=None) -> tf.Tensor: return tf.sqrt(self.cast(x, dtype)) - def sum(self, array: tf.Tensor, axes: Sequence[int] = None): - return tf.reduce_sum(array, axes) + def sum(self, array: tf.Tensor, axis: Sequence[int] = None): + return tf.reduce_sum(array, axis) @Autocast() def tensordot(self, a: tf.Tensor, b: tf.Tensor, axes: list[int]) -> tf.Tensor: @@ -618,9 +618,9 @@ def grad(dLdpoly): ) ax = tuple(range(dLdpoly.ndim)) - dLdA = self.sum(dLdpoly[..., None, None] * self.conj(dpoly_dA), axes=ax) - dLdB = self.sum(dLdpoly[..., None] * self.conj(dpoly_dB), axes=ax) - dLdC = self.sum(dLdpoly * self.conj(dpoly_dC), axes=ax) + dLdA = self.sum(dLdpoly[..., None, None] * self.conj(dpoly_dA), axis=ax) + dLdB = self.sum(dLdpoly[..., None] * self.conj(dpoly_dB), axis=ax) + dLdC = self.sum(dLdpoly * self.conj(dpoly_dC), axis=ax) return dLdA, dLdB, dLdC return poly0, grad @@ -724,9 +724,9 @@ def grad(dLdpoly): ) ax = tuple(range(dLdpoly.ndim)) - dLdA = self.sum(dLdpoly[..., None, None] * self.conj(dpoly_dA), axes=ax) - dLdB = self.sum(dLdpoly[..., None] * self.conj(dpoly_dB), axes=ax) - dLdC = self.sum(dLdpoly * self.conj(dpoly_dC), axes=ax) + dLdA = self.sum(dLdpoly[..., None, None] * self.conj(dpoly_dA), axis=ax) + dLdB = self.sum(dLdpoly[..., None] * self.conj(dpoly_dB), axis=ax) + dLdC = self.sum(dLdpoly * self.conj(dpoly_dC), axis=ax) return dLdA, dLdB, dLdC return poly0, grad diff --git a/mrmustard/physics/ansatz/array_ansatz.py b/mrmustard/physics/ansatz/array_ansatz.py index d792957ef..99be058c8 100644 --- a/mrmustard/physics/ansatz/array_ansatz.py +++ b/mrmustard/physics/ansatz/array_ansatz.py @@ -27,9 +27,10 @@ from IPython.display import display from mrmustard import math, widgets -from mrmustard.utils.typing import Batch, Scalar, Tensor, Vector +from mrmustard.utils.typing import Scalar, Tensor, Vector from .base import Ansatz +from ..batches import Batch __all__ = ["ArrayAnsatz"] @@ -49,16 +50,17 @@ class ArrayAnsatz(Ansatz): Args: array: A (potentially) batched array. - batched: Whether the array input has a batch dimension. + batch_labels: The (optional) batch labels for this ansatz. - Note: The args can be passed non-batched, as they will be automatically broadcasted to the - correct batch shape if ``batched`` is set to ``False``. """ - def __init__(self, array: Batch[Tensor], batched=False): + def __init__(self, array: Tensor | Batch[Tensor], batch_labels: list[str] | None = None): super().__init__() - self._array = array if batched else [array] - self._backend_array = False + self._array = ( + Batch(math.astensor([array]), batch_labels=batch_labels) + if array is not None and not isinstance(array, Batch) + else array + ) self._original_abc_data = None @property @@ -67,25 +69,19 @@ def array(self) -> Batch[Tensor]: The array of this ansatz. """ self._generate_ansatz() - if not self._backend_array: - self._array = math.astensor(self._array) - self._backend_array = True return self._array @array.setter def array(self, value): - self._array = value - self._backend_array = False + self._array = value if isinstance(value, Batch) else Batch(math.astensor([value])) @property - def batch_size(self): - return self.array.shape[0] + def batch_size(self) -> int: + return sum(self.array.batch_shape) @property - def conj(self): - ret = ArrayAnsatz(math.conj(self.array), batched=True) - ret._contract_idxs = self._contract_idxs - return ret + def conj(self) -> ArrayAnsatz: + return ArrayAnsatz(Batch(math.conj(self.array.data))) @property def data(self) -> Batch[Tensor]: @@ -93,7 +89,7 @@ def data(self) -> Batch[Tensor]: @property def num_vars(self) -> int: - return len(self.array.shape) - 1 + return len(self.array.core_shape) @property def scalar(self) -> Scalar: @@ -102,7 +98,7 @@ def scalar(self) -> Scalar: I.e. the vacuum component of the Fock array, whatever it may be. Given that the first axis of the array is the batch axis, this is the first element of the array. """ - return self.array[(slice(None),) + (0,) * self.num_vars] + return self.array.data[(slice(None),) * len(self.array.batch_shape) + (0,) * self.num_vars] @property def triple(self) -> tuple: @@ -115,15 +111,57 @@ def triple(self) -> tuple: @classmethod def from_dict(cls, data: dict[str, ArrayLike]) -> ArrayAnsatz: - return cls(data["array"], batched=True) + return cls(data["array"]) @classmethod def from_function(cls, fn: Callable, **kwargs: Any) -> ArrayAnsatz: - ret = cls(None, True) + ret = cls(None) ret._fn = fn ret._kwargs = kwargs return ret + def contract( + self, + other: ArrayAnsatz, + idx1: int | tuple[int, ...] | None = None, + idx2: int | tuple[int, ...] | None = None, + ) -> ArrayAnsatz: + idx1 = idx1 or () + idx2 = idx2 or () + idx1 = (idx1,) if isinstance(idx1, int) else idx1 + idx2 = (idx2,) if isinstance(idx2, int) else idx2 + for i, j in zip(idx1, idx2): + if i >= self.num_vars: + raise IndexError( + f"Index {i} out of bounds for representation with {self.num_vars} variables." + ) + if j >= other.num_vars: + raise IndexError( + f"Index {j} out of bounds for representation with {other.num_vars} variables." + ) + + n_batches_s = self.batch_size + n_batches_o = other.batch_size + + shape_s = self.array.core_shape + shape_o = other.array.core_shape + + new_shape_s = list(shape_s) + new_shape_o = list(shape_o) + for s, o in zip(idx1, idx2): + new_shape_s[s] = min(shape_s[s], shape_o[o]) + new_shape_o[o] = min(shape_s[s], shape_o[o]) + + reduced_s = self.reduce(new_shape_s) + reduced_o = other.reduce(new_shape_o) + + axes = [list(idx1), list(idx2)] + batched_array = [] + for i in range(n_batches_s): + for j in range(n_batches_o): + batched_array.append(math.tensordot(reduced_s.array[i], reduced_o.array[j], axes)) + return ArrayAnsatz(Batch(math.astensor(batched_array))) + def reduce(self, shape: int | Sequence[int]) -> ArrayAnsatz: r""" Returns a new ``ArrayAnsatz`` with a sliced array. @@ -150,7 +188,7 @@ def reduce(self, shape: int | Sequence[int]) -> ArrayAnsatz: Args: shape: The shape of the array of the returned ``ArrayAnsatz``. """ - if shape == self.array.shape[1:]: + if shape == self.array.core_shape: return self length = self.num_vars shape = (shape,) * length if isinstance(shape, int) else shape @@ -159,21 +197,24 @@ def reduce(self, shape: int | Sequence[int]) -> ArrayAnsatz: msg += f"given shape has length {len(shape)}." raise ValueError(msg) - if any(s > t for s, t in zip(shape, self.array.shape[1:])): + if any(s > t for s, t in zip(shape, self.array.core_shape)): warn( "Warning: the fock array is being padded with zeros. If possible, slice the arrays this one will contract with instead." ) padded = math.pad( self.array, - [(0, 0)] + [(0, s - t) for s, t in zip(shape, self.array.shape[1:])], + [(0, 0)] * len(self.array.batch_shape) + + [(0, s - t) for s, t in zip(shape, self.array.core_shape)], ) - return ArrayAnsatz(padded, batched=True) + return ArrayAnsatz(Batch(padded)) - ret = self.array[(slice(0, None),) + tuple(slice(0, s) for s in shape)] - return ArrayAnsatz(array=ret, batched=True) + ret = self.array.data[ + (slice(None),) * len(self.array.batch_shape) + tuple(slice(0, s) for s in shape) + ] + return ArrayAnsatz(array=Batch(ret)) def reorder(self, order: tuple[int, ...] | list[int]) -> ArrayAnsatz: - return ArrayAnsatz(math.transpose(self.array, [0] + [i + 1 for i in order]), batched=True) + return ArrayAnsatz(Batch(math.transpose(self.array, [0] + [i + 1 for i in order]))) def sum_batch(self) -> ArrayAnsatz: r""" @@ -182,7 +223,7 @@ def sum_batch(self) -> ArrayAnsatz: Returns: The collapsed ArrayAnsatz object. """ - return ArrayAnsatz(math.expand_dims(math.sum(self.array, axes=[0]), 0), batched=True) + return ArrayAnsatz(Batch(math.expand_dims(math.sum(self.array, axis=[0]), 0))) def to_dict(self) -> dict[str, ArrayLike]: return {"array": self.data} @@ -192,7 +233,7 @@ def trace(self, idx_z: tuple[int, ...], idx_zconj: tuple[int, ...]) -> ArrayAnsa raise ValueError("The idxs must be of equal length and disjoint.") order = ( [0] - + [i + 1 for i in range(len(self.array.shape) - 1) if i not in idx_z + idx_zconj] + + [i + 1 for i in range(len(self.array.core_shape)) if i not in idx_z + idx_zconj] + [i + 1 for i in idx_z] + [i + 1 for i in idx_zconj] ) @@ -200,11 +241,11 @@ def trace(self, idx_z: tuple[int, ...], idx_zconj: tuple[int, ...]) -> ArrayAnsa n = np.prod(new_array.shape[-len(idx_zconj) :]) new_array = math.reshape(new_array, new_array.shape[: -2 * len(idx_z)] + (n, n)) trace = math.trace(new_array) - return ArrayAnsatz([trace] if trace.shape == () else trace, batched=True) + return ArrayAnsatz(Batch(math.astensor([trace]) if trace.shape == () else trace)) def _generate_ansatz(self): if self._array is None: - self.array = [self._fn(**self._kwargs)] + self.array = self._fn(**self._kwargs) def _ipython_display_(self): if widgets.IN_INTERACTIVE_SHELL or (w := widgets.fock(self)) is None: @@ -214,70 +255,34 @@ def _ipython_display_(self): def __add__(self, other: ArrayAnsatz) -> ArrayAnsatz: try: - diff = sum(self.array.shape[1:]) - sum(other.array.shape[1:]) + diff = sum(self.array.core_shape) - sum(other.array.core_shape) if diff < 0: new_array = [ - a + b for a in self.reduce(other.array.shape[1:]).array for b in other.array + a + b for a in self.reduce(other.array.core_shape).array for b in other.array ] else: new_array = [ - a + b for a in self.array for b in other.reduce(self.array.shape[1:]).array + a + b for a in self.array for b in other.reduce(self.array.core_shape).array ] - return ArrayAnsatz(array=new_array, batched=True) + return ArrayAnsatz(array=Batch(math.astensor(new_array))) except Exception as e: raise TypeError(f"Cannot add {self.__class__} and {other.__class__}.") from e def __and__(self, other: ArrayAnsatz) -> ArrayAnsatz: new_array = [math.outer(a, b) for a in self.array for b in other.array] - return ArrayAnsatz(array=new_array, batched=True) + return ArrayAnsatz(array=Batch(math.astensor(new_array))) def __call__(self, z: Batch[Vector]) -> Scalar: raise AttributeError("Cannot call this ArrayAnsatz.") def __eq__(self, other: Ansatz) -> bool: - slices = (slice(0, None),) + tuple( - slice(0, min(si, oi)) for si, oi in zip(self.array.shape[1:], other.array.shape[1:]) + slices = (slice(None),) * len(self.array.batch_shape) + tuple( + slice(0, min(si, oi)) for si, oi in zip(self.array.core_shape, other.array.core_shape) ) - return np.allclose(self.array[slices], other.array[slices], atol=1e-10) - - def __getitem__(self, idx: int | tuple[int, ...]) -> ArrayAnsatz: - idx = (idx,) if isinstance(idx, int) else idx - for i in idx: - if i >= self.num_vars: - raise IndexError( - f"Index {i} out of bounds for representation with {self.num_vars} variables." - ) - ret = ArrayAnsatz(self.array, batched=True) - ret._contract_idxs = idx - return ret - - def __matmul__(self, other: ArrayAnsatz) -> ArrayAnsatz: - idx_s = list(self._contract_idxs) - idx_o = list(other._contract_idxs) - - # the number of batches in self and other - n_batches_s = self.array.shape[0] - n_batches_o = other.array.shape[0] - - # the shapes each batch in self and other - shape_s = self.array.shape[1:] - shape_o = other.array.shape[1:] + return np.allclose(self.array.data[slices], other.array.data[slices], atol=1e-10) - new_shape_s = list(shape_s) - new_shape_o = list(shape_o) - for s, o in zip(idx_s, idx_o): - new_shape_s[s] = min(shape_s[s], shape_o[o]) - new_shape_o[o] = min(shape_s[s], shape_o[o]) - - reduced_s = self.reduce(new_shape_s)[idx_s] - reduced_o = other.reduce(new_shape_o)[idx_o] - - axes = [list(idx_s), list(idx_o)] - batched_array = [] - for i in range(n_batches_s): - for j in range(n_batches_o): - batched_array.append(math.tensordot(reduced_s.array[i], reduced_o.array[j], axes)) - return ArrayAnsatz(batched_array, batched=True) + def __getitem__(self, idxs: int | slice | tuple[int, ...] | tuple[slice, ...]) -> ArrayAnsatz: + return ArrayAnsatz(self.array[idxs]) def __mul__(self, other: Scalar | ArrayAnsatz) -> ArrayAnsatz: if isinstance(other, ArrayAnsatz): @@ -291,11 +296,11 @@ def __mul__(self, other: Scalar | ArrayAnsatz) -> ArrayAnsatz: new_array = [ a * b for a in self.array for b in other.reduce(self.array.shape[1:]).array ] - return ArrayAnsatz(array=new_array, batched=True) + return ArrayAnsatz(array=Batch(math.astensor(new_array))) except Exception as e: raise TypeError(f"Cannot multiply {self.__class__} and {other.__class__}.") from e else: - ret = ArrayAnsatz(array=self.array * other, batched=True) + ret = ArrayAnsatz(array=self.array * other) ret._original_abc_data = ( tuple(i * j for i, j in zip(self._original_abc_data, (1, 1, other))) if self._original_abc_data is not None @@ -304,7 +309,7 @@ def __mul__(self, other: Scalar | ArrayAnsatz) -> ArrayAnsatz: return ret def __neg__(self) -> ArrayAnsatz: - return ArrayAnsatz(array=-self.array, batched=True) + return ArrayAnsatz(array=-self.array) def __truediv__(self, other: Scalar | ArrayAnsatz) -> ArrayAnsatz: if isinstance(other, ArrayAnsatz): @@ -318,11 +323,11 @@ def __truediv__(self, other: Scalar | ArrayAnsatz) -> ArrayAnsatz: new_array = [ a / b for a in self.array for b in other.reduce(self.array.shape[1:]).array ] - return ArrayAnsatz(array=new_array, batched=True) + return ArrayAnsatz(array=Batch(math.astensor(new_array))) except Exception as e: raise TypeError(f"Cannot divide {self.__class__} and {other.__class__}.") from e else: - ret = ArrayAnsatz(array=self.array / other, batched=True) + ret = ArrayAnsatz(array=self.array / other) ret._original_abc_data = ( tuple(i / j for i, j in zip(self._original_abc_data, (1, 1, other))) if self._original_abc_data is not None diff --git a/mrmustard/physics/ansatz/base.py b/mrmustard/physics/ansatz/base.py index 26086bc7d..9a0323b83 100644 --- a/mrmustard/physics/ansatz/base.py +++ b/mrmustard/physics/ansatz/base.py @@ -41,7 +41,6 @@ class Ansatz(ABC): """ def __init__(self) -> None: - self._contract_idxs: tuple[int, ...] = () self._fn = None self._kwargs = {} @@ -105,6 +104,25 @@ def from_function(cls, fn: Callable, **kwargs: Any) -> Ansatz: Returns an ansatz from a function and kwargs. """ + @abstractmethod + def contract( + self, + other: Ansatz, + idx1: int | tuple[int, ...] | None = None, + idx2: int | tuple[int, ...] | None = None, + ) -> Ansatz: + r""" + Contract two ansatz together. + + Args: + other: Another ansatz. + idx1: The (optional) index of the first ansatz to contract. + idx2: The (optional) index of the second ansatz to contract. + + Returns: + The resulting contracted ansatz. + """ + @abstractmethod def reorder(self, order: tuple[int, ...] | list[int]) -> Ansatz: r""" @@ -180,21 +198,15 @@ def __eq__(self, other: Ansatz) -> bool: """ @abstractmethod - def __getitem__(self, idx: int | tuple[int, ...]) -> Ansatz: + def __getitem__(self, idxs: int | slice | tuple[int, ...] | tuple[slice, ...]) -> Ansatz: r""" - Returns a copy of self with the given indices marked for contraction. - """ - - @abstractmethod - def __matmul__(self, other: Ansatz) -> Ansatz: - r""" - Implements the inner product of representations over the marked indices. + Slices this ansatz. Args: - other: Another ansatz. + idxs: The indices to slice the ansatz with. Returns: - The resulting ansatz. + The sliced ansatz. """ @abstractmethod diff --git a/mrmustard/physics/ansatz/polyexp_ansatz.py b/mrmustard/physics/ansatz/polyexp_ansatz.py index d65d28e7c..ebd6bd454 100644 --- a/mrmustard/physics/ansatz/polyexp_ansatz.py +++ b/mrmustard/physics/ansatz/polyexp_ansatz.py @@ -32,7 +32,6 @@ from IPython.display import display from mrmustard.utils.typing import ( - Batch, ComplexMatrix, ComplexTensor, ComplexVector, @@ -52,6 +51,7 @@ from mrmustard.utils.argsort import argsort_gen from .base import Ansatz +from ..batches import Batch __all__ = ["PolyExpAnsatz"] @@ -87,20 +87,34 @@ class PolyExpAnsatz(Ansatz): A: A batch of quadratic coefficient :math:`A_i`. b: A batch of linear coefficients :math:`b_i`. c: A batch of arrays :math:`c_i`. + batch_labels: The (optional) batch labels of this ansatz. + name: The (optional) name of this ansatz. """ def __init__( self, - A: Batch[ComplexMatrix], - b: Batch[ComplexVector], - c: Batch[ComplexTensor] = 1.0, + A: ComplexMatrix | Batch[ComplexMatrix] | None, + b: ComplexVector | Batch[ComplexVector] | None, + c: ComplexTensor | Batch[ComplexTensor] | None, + batch_labels: list[str] | None = None, name: str = "", ): super().__init__() - self._A = A - self._b = b - self._c = c - self._backends = [False, False, False] + self._A = ( + Batch(math.atleast_3d(A), batch_labels=batch_labels) + if A is not None and not isinstance(A, Batch) + else A + ) + self._b = ( + Batch(math.atleast_2d(b), batch_labels=batch_labels) + if b is not None and not isinstance(b, Batch) + else b + ) + self._c = ( + Batch(math.atleast_1d(c), batch_labels=batch_labels) + if c is not None and not isinstance(c, Batch) + else c + ) self._simplified = False self.name = name @@ -110,15 +124,11 @@ def A(self) -> Batch[ComplexMatrix]: The batch of quadratic coefficient :math:`A_i`. """ self._generate_ansatz() - if not self._backends[0]: - self._A = math.atleast_3d(self._A) - self._backends[0] = True return self._A @A.setter - def A(self, value): - self._A = value - self._backends[0] = False + def A(self, value: ComplexMatrix | Batch[ComplexMatrix]): + self._A = value if isinstance(value, Batch) else Batch(math.atleast_3d(value)) @property def b(self) -> Batch[ComplexVector]: @@ -126,19 +136,15 @@ def b(self) -> Batch[ComplexVector]: The batch of linear coefficients :math:`b_i` """ self._generate_ansatz() - if not self._backends[1]: - self._b = math.atleast_2d(self._b) - self._backends[1] = True return self._b @b.setter - def b(self, value): - self._b = value - self._backends[1] = False + def b(self, value: ComplexVector | Batch[ComplexVector]): + self._b = value if isinstance(value, Batch) else Batch(math.atleast_2d(value)) @property def batch_size(self): - return self.c.shape[0] + return sum(self.c.batch_shape) @property def c(self) -> Batch[ComplexTensor]: @@ -146,21 +152,15 @@ def c(self) -> Batch[ComplexTensor]: The batch of arrays :math:`c_i`. """ self._generate_ansatz() - if not self._backends[2]: - self._c = math.atleast_1d(self._c) - self._backends[2] = True return self._c @c.setter - def c(self, value): - self._c = value - self._backends[2] = False + def c(self, value: ComplexTensor | Batch[ComplexTensor]): + self._c = value if isinstance(value, Batch) else Batch(math.atleast_1d(value)) @property - def conj(self): - ret = PolyExpAnsatz(math.conj(self.A), math.conj(self.b), math.conj(self.c)) - ret._contract_idxs = self._contract_idxs - return ret + def conj(self) -> PolyExpAnsatz: + return PolyExpAnsatz(math.conj(self.A), math.conj(self.b), math.conj(self.c)) @property def data( @@ -169,8 +169,8 @@ def data( return self.triple @property - def num_vars(self): - return self.A.shape[-1] - self.polynomial_shape[0] + def num_vars(self) -> int: + return self.A.core_shape[-1] - self.polynomial_shape[0] @property def polynomial_shape(self) -> tuple[int, tuple]: @@ -179,7 +179,7 @@ def polynomial_shape(self) -> tuple[int, tuple]: have polynomials attached to them and what the degree(+1) of the polynomial is on each of the wires. """ - shape_poly = self.c.shape[1:] + shape_poly = self.c.core_shape dim_poly = len(shape_poly) return dim_poly, shape_poly @@ -207,6 +207,39 @@ def from_function(cls, fn: Callable, **kwargs: Any) -> PolyExpAnsatz: ansatz._kwargs = kwargs return ansatz + def contract( + self, + other: PolyExpAnsatz, + idx1: int | tuple[int, ...] | None = None, + idx2: int | tuple[int, ...] | None = None, + ) -> PolyExpAnsatz: + idx1 = idx1 or () + idx2 = idx2 or () + idx1 = (idx1,) if isinstance(idx1, int) else idx1 + idx2 = (idx2,) if isinstance(idx2, int) else idx2 + for i, j in zip(idx1, idx2): + if i and i >= self.num_vars: + raise IndexError( + f"Index {i} out of bounds for ansatz of dimension {self.num_vars}." + ) + if j and j >= other.num_vars: + raise IndexError( + f"Index {j} out of bounds for ansatz of dimension {other.num_vars}." + ) + + if settings.UNSAFE_ZIP_BATCH: + if self.batch_size != other.batch_size: + raise ValueError( + f"Batch size of the two representations must match since the settings.UNSAFE_ZIP_BATCH is {settings.UNSAFE_ZIP_BATCH}." + ) + A, b, c = complex_gaussian_integral_2(self.triple, other.triple, idx1, idx2, mode="zip") + else: + A, b, c = complex_gaussian_integral_2( + self.triple, other.triple, idx1, idx2, mode="kron" + ) + + return PolyExpAnsatz(A, b, c) + def decompose_ansatz(self) -> PolyExpAnsatz: r""" This method decomposes a PolyExp ansatz. Given an ansatz of dimension: @@ -214,10 +247,15 @@ def decompose_ansatz(self) -> PolyExpAnsatz: it can be rewritten as an ansatz of dimension A=(batch,2n,2n), b=(batch,2n), c = (batch,l_1,l_2,...,l_n), with l_i = sum_j k_j This decomposition is typically favourable if m>n, and will only run if that is the case. - The naming convention is ``n = dim_alpha`` and ``m = dim_beta`` and ``(k_1,k_2,...,k_m) = shape_beta`` + The naming convention is ``n = dim_alpha`` and ``m = dim_beta`` and ``(k_1,k_2,...,k_m) = shape_beta``. + + Raises: + NotImplementedError: If the number of batch dimensions is greater than 1. """ + if len(self.c.batch_shape) > 1: # pragma: no cover + raise NotImplementedError("``decompose_ansatz`` is only compatible with 1-D batches.") dim_beta, _ = self.polynomial_shape - dim_alpha = self.A.shape[-1] - dim_beta + dim_alpha = self.A.core_shape[-1] - dim_beta batch_size = self.batch_size if dim_beta > dim_alpha: A_decomp = [] @@ -233,7 +271,7 @@ def decompose_ansatz(self) -> PolyExpAnsatz: return PolyExpAnsatz(A_decomp, b_decomp, c_decomp) else: - return PolyExpAnsatz(self.A, self.b, self.c) + return self def plot( self, @@ -310,17 +348,20 @@ def simplify(self) -> None: return indices_to_check = set(range(self.batch_size)) removed = [] + temp_c = self.c.data while indices_to_check: i = indices_to_check.pop() for j in indices_to_check.copy(): - if np.allclose(self.A[i], self.A[j]) and np.allclose(self.b[i], self.b[j]): - self.c = math.update_add_tensor(self.c, [[i]], [self.c[j]]) + if math.allclose(self.A[i], self.A[j]) and math.allclose(self.b[i], self.b[j]): + temp_c = math.update_add_tensor(temp_c, [[i]], [temp_c[j]]) indices_to_check.remove(j) removed.append(j) to_keep = [i for i in range(self.batch_size) if i not in removed] + self.A = math.gather(self.A, to_keep, axis=0) self.b = math.gather(self.b, to_keep, axis=0) - self.c = math.gather(self.c, to_keep, axis=0) + self.c = math.gather(temp_c, to_keep, axis=0) + self._simplified = True def simplify_v2(self) -> None: @@ -332,16 +373,17 @@ def simplify_v2(self) -> None: self._order_batch() to_keep = [d0 := 0] mat, vec = self.A[d0], self.b[d0] + temp_c = self.c.data for d in range(1, self.batch_size): if np.allclose(mat, self.A[d]) and np.allclose(vec, self.b[d]): - self.c = math.update_add_tensor(self.c, [[d0]], [self.c[d]]) + temp_c = math.update_add_tensor(temp_c, [[d0]], [temp_c[d]]) else: to_keep.append(d) d0 = d mat, vec = self.A[d0], self.b[d0] self.A = math.gather(self.A, to_keep, axis=0) self.b = math.gather(self.b, to_keep, axis=0) - self.c = math.gather(self.c, to_keep, axis=0) + self.c = math.gather(temp_c, to_keep, axis=0) self._simplified = True def to_dict(self) -> dict[str, ArrayLike]: @@ -379,26 +421,26 @@ def _call_all(self, z: Batch[Vector]) -> PolyExpAnsatz: zz = math.einsum("...a,...b->...ab", z, z)[..., None, :, :] # shape (b_arg, 1, n, n)) A_part = math.sum( - self.A[..., :dim_alpha, :dim_alpha] * zz, axes=[-1, -2] + self.A.data[..., :dim_alpha, :dim_alpha] * zz, axis=[-1, -2] ) # sum((b_arg,1,n,n) * (b_abc,n,n), [-1,-2]) ~ (b_arg,b_abc) b_part = math.sum( - self.b[..., :dim_alpha] * z[..., None, :], axes=[-1] + self.b.data[..., :dim_alpha] * z[..., None, :], axis=[-1] ) # sum((b_arg,1,n) * (b_abc,n), [-1]) ~ (b_arg,b_abc) exp_sum = math.exp(1 / 2 * A_part + b_part) # (b_arg, b_abc) if dim_beta == 0: - val = math.sum(exp_sum * self.c, axes=[-1]) # (b_arg) + val = math.sum(exp_sum * self.c.data, axis=[-1]) # (b_arg) else: b_poly = math.astensor( math.einsum( "ijk,hk", - math.cast(self.A[..., dim_alpha:, :dim_alpha], "complex128"), + math.cast(self.A.data[..., dim_alpha:, :dim_alpha], "complex128"), math.cast(z, "complex128"), ) - + self.b[..., dim_alpha:] + + self.b.data[..., dim_alpha:] ) # (b_arg, b_abc, m) b_poly = math.moveaxis(b_poly, 0, 1) # (b_abc, b_arg, m) - A_poly = self.A[..., dim_alpha:, dim_alpha:] # (b_abc, m) + A_poly = self.A.data[..., dim_alpha:, dim_alpha:] # (b_abc, m) poly = math.astensor( [ math.hermite_renormalized_batch(A_poly[i], b_poly[i], complex(1), shape_beta) @@ -409,10 +451,10 @@ def _call_all(self, z: Batch[Vector]) -> PolyExpAnsatz: val = math.sum( exp_sum * math.sum( - poly * self.c, - axes=math.arange(2, 2 + dim_beta, dtype=math.int32).tolist(), + poly * self.c.data, + axis=math.arange(2, 2 + dim_beta, dtype=math.int32).tolist(), ), - axes=[-1], + axis=[-1], ) # (b_arg) return val @@ -442,7 +484,10 @@ def _call_none(self, z: Batch[Vector]) -> PolyExpAnsatz: arg_index = 0 if batch_arg == 1 else i Abc.append( self._call_none_single( - self.A[abc_index], self.b[abc_index], self.c[abc_index], z[arg_index] + self.A.data[abc_index], + self.b.data[abc_index], + self.c.data[abc_index], + z[arg_index], ) ) A, b, c = zip(*Abc) @@ -520,7 +565,7 @@ def _decompose_ansatz_single(self, Ai, bi, ci): ) c_decomp = math.sum( poly_bar * ci, - axes=math.arange( + axis=math.arange( len(poly_bar.shape) - dim_beta, len(poly_bar.shape), dtype=math.int32 ).tolist(), ) @@ -542,7 +587,9 @@ def _decompose_ansatz_single(self, Ai, bi, ci): def _equal_no_array(self, other: PolyExpAnsatz) -> bool: self.simplify() other.simplify() - return np.allclose(self.b, other.b, atol=1e-10) and np.allclose(self.A, other.A, atol=1e-10) + return math.allclose(self.b, other.b, atol=1e-10) and math.allclose( + self.A, other.A, atol=1e-10 + ) def _generate_ansatz(self): r""" @@ -603,8 +650,8 @@ def __add__(self, other: PolyExpAnsatz) -> PolyExpAnsatz: """ if not isinstance(other, PolyExpAnsatz): raise TypeError(f"Cannot add {self.__class__} and {other.__class__}.") - (_, n1, _) = self.A.shape - (_, n2, _) = other.A.shape + (n1, _) = self.A.core_shape + (n2, _) = other.A.core_shape self_num_poly, _ = self.polynomial_shape other_num_poly, _ = other.polynomial_shape if self_num_poly - other_num_poly != n1 - n2: @@ -641,6 +688,7 @@ def combine_arrays(array1, array2): combined_matrices = math.concat([self.A, mat2], axis=0) combined_vectors = math.concat([self.b, vec2], axis=0) combined_arrays = combine_arrays(self.c, array2) + # note output is not simplified return PolyExpAnsatz(combined_matrices, combined_vectors, combined_arrays) @@ -754,58 +802,8 @@ def __eq__(self, other: PolyExpAnsatz) -> bool: return False return self._equal_no_array(other) and np.allclose(self.c, other.c, atol=1e-10) - def __getitem__(self, idx: int | tuple[int, ...]) -> PolyExpAnsatz: - idx = (idx,) if isinstance(idx, int) else idx - for i in idx: - if i >= self.num_vars: - raise IndexError( - f"Index {i} out of bounds for ansatz of dimension {self.num_vars}." - ) - ret = PolyExpAnsatz(self.A, self.b, self.c) - ret._contract_idxs = idx - return ret - - def __matmul__(self, other: PolyExpAnsatz) -> PolyExpAnsatz: - r""" - Implements the inner product between PolyExpAnsatz. - - ..code-block:: - - >>> from mrmustard.physics.ansatz import PolyExpAnsatz - >>> from mrmustard.physics.triples import displacement_gate_Abc, vacuum_state_Abc - >>> rep1 = PolyExpAnsatz(*vacuum_state_Abc(1)) - >>> rep2 = PolyExpAnsatz(*displacement_gate_Abc(1)) - >>> rep3 = rep1[0] @ rep2[1] - >>> assert np.allclose(rep3.A, [[0,],]) - >>> assert np.allclose(rep3.b, [1,]) - - Args: - other: Another PolyExpAnsatz . - - Returns: - Bargmann: the resulting PolyExpAnsatz. - - """ - if not isinstance(other, PolyExpAnsatz): - raise NotImplementedError("Only matmul PolyExpAnsatz with PolyExpAnsatz") - - idx_s = self._contract_idxs - idx_o = other._contract_idxs - - if settings.UNSAFE_ZIP_BATCH: - if self.batch_size != other.batch_size: - raise ValueError( - f"Batch size of the two representations must match since the settings.UNSAFE_ZIP_BATCH is {settings.UNSAFE_ZIP_BATCH}." - ) - A, b, c = complex_gaussian_integral_2( - self.triple, other.triple, idx_s, idx_o, mode="zip" - ) - else: - A, b, c = complex_gaussian_integral_2( - self.triple, other.triple, idx_s, idx_o, mode="kron" - ) - - return PolyExpAnsatz(A, b, c) + def __getitem__(self, idxs: int | slice | tuple[int, ...] | tuple[slice, ...]) -> PolyExpAnsatz: + return PolyExpAnsatz(self.A[idxs], self.b[idxs], self.c[idxs]) def __mul__(self, other: Scalar | PolyExpAnsatz) -> PolyExpAnsatz: def mul_A(A1, A2, dim_alpha, dim_beta1, dim_beta2): diff --git a/mrmustard/physics/batches.py b/mrmustard/physics/batches.py new file mode 100644 index 000000000..fa6bd012e --- /dev/null +++ b/mrmustard/physics/batches.py @@ -0,0 +1,188 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module contains the Batch class. +""" + +# pylint: disable=too-many-instance-attributes + +from __future__ import annotations +from typing import Any, Collection, Iterable + +import string +import random + +from mrmustard import math +from mrmustard.utils.typing import ( + ComplexMatrix, + ComplexTensor, + ComplexVector, + Scalar, +) + +__all__ = ["Batch"] + + +class Batch: + r""" + The class responsible for keeping track of and handling batch dimensions. + + Args: + data: The batched array. + batch_shape: The (optional) shape of the batch dims. Defaults to the first dimension of ``data``. + batch_labels: The (optional) labels for the batch dims. Defaults to random characters. + """ + + def __init__( + self, + data: ComplexMatrix | ComplexVector | ComplexTensor, + batch_shape: tuple[int, ...] | None = None, + batch_labels: tuple[str, ...] | None = None, + ): + self._data = data + self.dtype = self._data.dtype + self._batch_shape = batch_shape or self._data.shape[:1] + if self._data.shape[: len(self._batch_shape)] != self._batch_shape: + raise ValueError( + f"Invalid batch shape {self._batch_shape} for data shape {self._data.shape}." + ) + self._batch_labels = ( + batch_labels + if batch_labels + else tuple((random.choice(string.ascii_letters) for _ in self._batch_shape)) + ) + self._core_shape = self._data.shape[len(self._batch_shape) :] + + @property + def batch_labels(self) -> tuple[str, ...]: + r""" + The batch labels. + """ + return self._batch_labels + + @property + def batch_shape(self) -> tuple[int, ...]: + r""" + The batch shape. + """ + return self._batch_shape + + @property + def core_shape(self) -> tuple[int, ...]: + r""" + The core shape. + """ + return self._core_shape + + @property + def data(self) -> ComplexMatrix | ComplexVector | ComplexTensor: + r""" + The underlying batched data. + """ + return self._data + + @property + def shape(self) -> tuple[int, ...]: + r""" + The overall shape (batch_shape + core_shape). + """ + return self.data.shape + + def __array__(self): + return math.asnumpy(self.data) + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): # pragma: no cover + r""" + Implement the NumPy ufunc interface. + """ + if method == "__call__": + inputs = [i.data if isinstance(i, Batch) else i for i in inputs] + return Batch(ufunc(*inputs, **kwargs), self.batch_shape, self.batch_labels) + elif method == "reduce": + axes = kwargs.get("axis") or (0,) + + if any(axis > len(self.batch_shape) - 1 for axis in axes): + raise ValueError("Axis out of bounds.") + input = ( + inputs[0].data if isinstance(inputs[0], Batch) else inputs[0] + ) # assume single input + batch_shape = tuple( + (shape for idx, shape in enumerate(self.batch_shape) if idx not in axes) + ) + batch_labels = tuple( + (label for idx, label in enumerate(self.batch_labels) if idx not in axes) + ) + data = ufunc(input, **kwargs) + return Batch(data, batch_shape, batch_labels) if batch_shape else data + + else: + # TODO: implement more methods as needed + raise NotImplementedError(f"Cannot call {method} on {ufunc}.") + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Batch): + return False + return ( + math.allclose(self.data, other.data) + and self.batch_shape == other.batch_shape + and self.batch_labels == other.batch_labels + ) + + def __getitem__( + self, idxs: int | slice | tuple[int, ...] | tuple[slice, ...] + ) -> ComplexMatrix | ComplexVector | ComplexTensor | Batch: + r""" + Index the batch dimensions. + + Note: + To index core dimensions use ``self.data``. + """ + idxs = (idxs,) if not isinstance(idxs, Collection) else idxs + if len(idxs) > len(self.batch_shape): + raise IndexError( + f"Too many indices for batched array: batch is {len(self.batch_shape)}-dimensional, but {len(idxs)} were indexed." + ) + new_data = self.data[idxs] + new_batch_shape = ( + new_data.shape[: len(self.core_shape) - 1] + if len(self.core_shape) < len(new_data.shape) + else () + ) + new_batch_labels = ( + tuple(self.batch_labels[i] for i, j in enumerate(idxs) if isinstance(j, slice)) + + self.batch_labels[len(idxs) :] + ) + return Batch(new_data, new_batch_shape, new_batch_labels) if new_batch_shape else new_data + + def __iter__(self) -> Iterable: + return iter(self.data) + + def __len__(self) -> int: + return len(self.data) + + def __mul__(self, other: Scalar) -> Batch: + return Batch(self.data * other, self.batch_shape, self.batch_labels) + + def __neg__(self) -> Batch: + return -1 * self + + def __rmul__(self, other: Scalar) -> Batch: + return self * other + + def __rtruediv__(self, other: Scalar) -> Batch: + return Batch(other / self.data, self.batch_shape, self.batch_labels) + + def __truediv__(self, other: Scalar) -> Batch: + return Batch(self.data / other, self.batch_shape, self.batch_labels) diff --git a/mrmustard/physics/fock_utils.py b/mrmustard/physics/fock_utils.py index d18f98343..d8a1b0db4 100644 --- a/mrmustard/physics/fock_utils.py +++ b/mrmustard/physics/fock_utils.py @@ -342,7 +342,7 @@ def number_means(tensor, is_dm: bool): r"""Returns the mean of the number operator in each mode.""" probs = math.all_diagonals(tensor, real=True) if is_dm else math.abs(tensor) ** 2 modes = list(range(len(probs.shape))) - marginals = [math.sum(probs, axes=modes[:k] + modes[k + 1 :]) for k in range(len(modes))] + marginals = [math.sum(probs, axis=modes[:k] + modes[k + 1 :]) for k in range(len(modes))] return math.astensor( [ math.sum(marginal * math.arange(len(marginal), dtype=math.float64)) @@ -355,7 +355,7 @@ def number_variances(tensor, is_dm: bool): r"""Returns the variance of the number operator in each mode.""" probs = math.all_diagonals(tensor, real=True) if is_dm else math.abs(tensor) ** 2 modes = list(range(len(probs.shape))) - marginals = [math.sum(probs, axes=modes[:k] + modes[k + 1 :]) for k in range(len(modes))] + marginals = [math.sum(probs, axis=modes[:k] + modes[k + 1 :]) for k in range(len(modes))] return math.astensor( [ ( diff --git a/mrmustard/physics/gaussian_integrals.py b/mrmustard/physics/gaussian_integrals.py index 954f243f0..c36eddc4e 100644 --- a/mrmustard/physics/gaussian_integrals.py +++ b/mrmustard/physics/gaussian_integrals.py @@ -382,7 +382,7 @@ def complex_gaussian_integral_1( inv_M = math.inv(M) c_post = c * math.reshape( math.sqrt(math.cast((-1) ** m / det_M, "complex128")) - * math.exp(-0.5 * math.sum(bM * math.solve(M, bM), axes=[-1])), + * math.exp(-0.5 * math.sum(bM * math.solve(M, bM), axis=[-1])), c.shape[:1] + (1,) * (len(c.shape) - 1), ) A_post = R - math.einsum("bij,bjk,blk->bil", D, inv_M, D) diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index 053495b69..3d2f52246 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -27,10 +27,10 @@ ComplexTensor, ComplexMatrix, ComplexVector, - Batch, ) from .ansatz import Ansatz, PolyExpAnsatz, ArrayAnsatz +from .batches import Batch from .triples import identity_Abc from .wires import Wires @@ -208,7 +208,9 @@ def bargmann_triple( except AttributeError as e: raise AttributeError("No Bargmann data for this component.") from e - def fock_array(self, shape: int | Sequence[int], batched=False) -> ComplexTensor: + def fock_array( + self, shape: int | Sequence[int], batched: bool = False + ) -> ComplexTensor | Batch[ComplexTensor]: r""" Returns an array of this representation in the Fock basis with the given shape. If the shape is not given, it defaults to the ``auto_shape`` of the component if it is @@ -233,17 +235,18 @@ def fock_array(self, shape: int | Sequence[int], batched=False) -> ComplexTensor ) if self.ansatz.polynomial_shape[0] == 0: arrays = [ - math.hermite_renormalized(A, b, c, shape=shape) for A, b, c in zip(As, bs, cs) + math.hermite_renormalized(A, b, c, shape=shape) + for A, b, c in zip(As.data, bs.data, cs.data) ] else: arrays = [ math.sum( math.hermite_renormalized(A, b, 1, shape=shape + c.shape) * c, - axes=math.arange( + axis=math.arange( num_vars, num_vars + len(c.shape), dtype=math.int32 ).tolist(), ) - for A, b, c in zip(As, bs, cs) + for A, b, c in zip(As.data, bs.data, cs.data) ] except AttributeError as e: if len(shape) != num_vars: @@ -251,8 +254,8 @@ def fock_array(self, shape: int | Sequence[int], batched=False) -> ComplexTensor f"Expected Fock shape of length {num_vars}, got length {len(shape)}" ) from e arrays = self.ansatz.reduce(shape).array - array = math.sum(arrays, axes=[0]) - arrays = math.expand_dims(array, 0) if batched else array + array = math.sum(arrays, axis=[0]) + arrays = Batch(math.expand_dims(array, 0)) if batched else array return arrays def to_bargmann(self) -> Representation: @@ -279,7 +282,7 @@ def to_fock(self, shape: int | Sequence[int]) -> Representation: an ``int``, it is broadcasted to all the dimensions. If ``None``, it defaults to the value of ``AUTOSHAPE_MAX`` in the settings. """ - fock = ArrayAnsatz(self.fock_array(shape, batched=True), batched=True) + fock = ArrayAnsatz(self.fock_array(shape, batched=True)) try: if self.ansatz.polynomial_shape[0] == 0: fock._original_abc_data = self.ansatz.triple @@ -344,8 +347,7 @@ def __matmul__(self, other: Representation): else: self_ansatz = self.to_bargmann().ansatz other_ansatz = other.to_bargmann().ansatz - - rep = self_ansatz[idx_z] @ other_ansatz[idx_zconj] + rep = self_ansatz.contract(other_ansatz, idx_z, idx_zconj) rep = rep.reorder(perm) if perm else rep idx_reps = self._matmul_idx_reps(wires_result, other) return Representation(rep, wires_result, idx_reps) diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index aa3bbbc97..f7eca2f88 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -498,7 +498,7 @@ def test_to_fock_keeps_bargmann(self): def test_fock_component_no_bargmann(self): "tests that a fock component doesn't have a bargmann representation by default" coh = Coherent([0], x=1.0) - CC = Ket.from_fock([0], coh.fock_array(20), batched=False) + CC = Ket.from_fock([0], coh.fock_array(20)) with pytest.raises(AttributeError): CC.bargmann_triple() # pylint: disable=pointless-statement diff --git a/tests/test_lab_dev/test_states/test_dm.py b/tests/test_lab_dev/test_states/test_dm.py index 61ef0047c..85f467571 100644 --- a/tests/test_lab_dev/test_states/test_dm.py +++ b/tests/test_lab_dev/test_states/test_dm.py @@ -26,6 +26,7 @@ from mrmustard.lab_dev.circuit_components_utils import TraceOut from mrmustard.lab_dev.states import DM, Coherent, Ket, Number, Vacuum from mrmustard.lab_dev.transformations import Attenuator, Dgate +from mrmustard.physics.batches import Batch from mrmustard.physics.gaussian import vacuum_cov from mrmustard.physics.representations import Representation from mrmustard.physics.wires import Wires @@ -99,7 +100,7 @@ def test_from_fock_error(self): state01 = Coherent([0, 1], 1).dm() state01 = state01.to_fock(2) with pytest.raises(ValueError): - DM.from_fock([0], state01.fock_array(5), "my_dm", True) + DM.from_fock([0], Batch(state01.fock_array(5)), "my_dm") def test_bargmann_triple_error(self): fock = Number([0], n=10).dm() @@ -128,7 +129,7 @@ def test_to_from_fock(self, modes): assert math.allclose(array_in, state_in_fock.ansatz.array) - state_out = DM.from_fock(modes, array_in, "my_dm", True) + state_out = DM.from_fock(modes, array_in, "my_dm") assert state_in_fock == state_out def test_to_from_phase_space(self): diff --git a/tests/test_lab_dev/test_states/test_ket.py b/tests/test_lab_dev/test_states/test_ket.py index 071ef0c32..1bf434e98 100644 --- a/tests/test_lab_dev/test_states/test_ket.py +++ b/tests/test_lab_dev/test_states/test_ket.py @@ -131,7 +131,7 @@ def test_to_from_fock(self, modes): assert math.allclose(array_in, state_in_fock.ansatz.array) - state_out = Ket.from_fock(modes, array_in, "my_ket", True) + state_out = Ket.from_fock(modes, array_in, "my_ket") assert state_in_fock == state_out @pytest.mark.parametrize("modes", [[0], [0, 1], [3, 19, 2]]) diff --git a/tests/test_lab_dev/test_transformations/test_cft.py b/tests/test_lab_dev/test_transformations/test_cft.py index 95370a47f..f93812ca7 100644 --- a/tests/test_lab_dev/test_transformations/test_cft.py +++ b/tests/test_lab_dev/test_transformations/test_cft.py @@ -41,7 +41,7 @@ def test_wigner_function(self): state = Ket.random([0]) >> Dgate([0], x=1.0, y=0.1) - dm = math.sum(state.to_fock(100).dm().ansatz.array, axes=[0]) + dm = math.sum(state.to_fock(100).dm().ansatz.array, axis=[0]) vec = np.linspace(-5, 5, 100) wigner, _, _ = wigner_discretized(dm, vec, vec) diff --git a/tests/test_lab_dev/test_transformations/test_transformations_base.py b/tests/test_lab_dev/test_transformations/test_transformations_base.py index 9cc0cdb84..40c1d003e 100644 --- a/tests/test_lab_dev/test_transformations/test_transformations_base.py +++ b/tests/test_lab_dev/test_transformations/test_transformations_base.py @@ -46,8 +46,8 @@ def test_init_from_bargmann(self): b = np.array([0, 1, 5]) c = 1 operator = Operation.from_bargmann([0], [1, 2], (A, b, c), "my_operator") - assert np.allclose(operator.ansatz.A[None, ...], A) - assert np.allclose(operator.ansatz.b[None, ...], b) + assert np.allclose(operator.ansatz.A[None], A) + assert np.allclose(operator.ansatz.b[None], b) class TestUnitary: @@ -92,8 +92,8 @@ def test_init_from_bargmann(self): b = np.array([0, 0]) c = 1 gate = Unitary.from_bargmann([2], [2], (A, b, c), "my_unitary") - assert np.allclose(gate.ansatz.A[None, ...], A) - assert np.allclose(gate.ansatz.b[None, ...], b) + assert np.allclose(gate.ansatz.A[None], A) + assert np.allclose(gate.ansatz.b[None], b) def test_init_from_symplectic(self): S = math.random_symplectic(2) @@ -132,8 +132,8 @@ def test_init_from_bargmann(self): b = np.array([0, 1, 2, 3]) c = 1 map = Map.from_bargmann([0], [0], (A, b, c), "my_map") - assert np.allclose(map.ansatz.A[None, ...], A) - assert np.allclose(map.ansatz.b[None, ...], b) + assert np.allclose(map.ansatz.A[None], A) + assert np.allclose(map.ansatz.b[None], b) class TestChannel: @@ -160,8 +160,8 @@ def test_init_from_bargmann(self): b = np.array([0, 1, 2, 3]) c = 1 channel = Channel.from_bargmann([0], [0], (A, b, c), "my_channel") - assert np.allclose(channel.ansatz.A[None, ...], A) - assert np.allclose(channel.ansatz.b[None, ...], b) + assert np.allclose(channel.ansatz.A[None], A) + assert np.allclose(channel.ansatz.b[None], b) def test_rshift(self): unitary = Dgate([0, 1], 1) @@ -194,7 +194,7 @@ def test_random(self): @pytest.mark.parametrize("modes", [[0], [0, 1], [0, 1, 2]]) def test_is_CP(self, modes): u = Unitary.random(modes).ansatz - kraus = u @ u.conj + kraus = u.contract(u.conj) assert Channel.from_bargmann(modes, modes, kraus.triple).is_CP def test_is_TP(self): @@ -206,7 +206,7 @@ def test_is_physical(self): def test_XY(self): U = Unitary.random([0, 1]) u = U.ansatz - unitary_channel = Channel.from_bargmann([0, 1], [0, 1], (u.conj @ u).triple) + unitary_channel = Channel.from_bargmann([0, 1], [0, 1], u.conj.contract(u).triple) X, Y = unitary_channel.XY assert np.allclose(X, U.symplectic) and np.allclose(Y, np.zeros(4)) diff --git a/tests/test_math/test_backend_manager.py b/tests/test_math/test_backend_manager.py index 4fff18c69..5670428fe 100644 --- a/tests/test_math/test_backend_manager.py +++ b/tests/test_math/test_backend_manager.py @@ -47,7 +47,7 @@ def test_error(self): """ msg = f"Function ``ciao`` not implemented for backend ``{math.backend_name}``." with pytest.raises(NotImplementedError, match=msg): - math._apply("ciao") + math._get_fn("ciao") def test_types(self): r""" diff --git a/tests/test_physics/test_ansatz/test_array_ansatz.py b/tests/test_physics/test_ansatz/test_array_ansatz.py index 410435d40..5aeff6bfb 100644 --- a/tests/test_physics/test_ansatz/test_array_ansatz.py +++ b/tests/test_physics/test_ansatz/test_array_ansatz.py @@ -25,6 +25,7 @@ from mrmustard import math from mrmustard.physics.ansatz.array_ansatz import ArrayAnsatz +from mrmustard.physics.batches import Batch class TestArrayAnsatz: @@ -36,19 +37,19 @@ class TestArrayAnsatz: array5578 = np.random.random((5, 5, 7, 8)) def test_init_batched(self): - fock = ArrayAnsatz(self.array1578, batched=True) + fock = ArrayAnsatz(Batch(self.array1578)) assert isinstance(fock, ArrayAnsatz) assert np.allclose(fock.array, self.array1578) def test_init_non_batched(self): - fock = ArrayAnsatz(self.array578, batched=False) + fock = ArrayAnsatz(self.array578) assert isinstance(fock, ArrayAnsatz) assert fock.array.shape == (1, 5, 7, 8) - assert np.allclose(fock.array[0, :, :, :], self.array578) + assert np.allclose(fock.array[0], self.array578) def test_add(self): - fock1 = ArrayAnsatz(self.array2578, batched=True) - fock2 = ArrayAnsatz(self.array5578, batched=True) + fock1 = ArrayAnsatz(Batch(self.array2578)) + fock2 = ArrayAnsatz(Batch(self.array5578)) fock1_add_fock2 = fock1 + fock2 assert fock1_add_fock2.array.shape == (10, 5, 7, 8) assert np.allclose(fock1_add_fock2.array[0], self.array2578[0] + self.array5578[0]) @@ -77,8 +78,8 @@ def test_algebra_with_different_shape_of_array_raise_errors(self): aa1 == aa2 # pylint: disable=pointless-statement def test_and(self): - fock1 = ArrayAnsatz(self.array1578, batched=True) - fock2 = ArrayAnsatz(self.array5578, batched=True) + fock1 = ArrayAnsatz(Batch(self.array1578)) + fock2 = ArrayAnsatz(Batch(self.array5578)) fock_test = fock1 & fock2 assert fock_test.array.shape == (5, 5, 7, 8, 5, 7, 8) assert np.allclose( @@ -87,17 +88,28 @@ def test_and(self): ) def test_call(self): - fock = ArrayAnsatz(self.array1578, batched=True) + fock = ArrayAnsatz(Batch(self.array1578)) with pytest.raises(AttributeError, match="Cannot call"): fock(0) def test_conj(self): - fock = ArrayAnsatz(self.array1578, batched=True) + fock = ArrayAnsatz(Batch(self.array1578)) fock_conj = fock.conj assert np.allclose(fock_conj.array, np.conj(self.array1578)) + def test_contract_fock_fock(self): + array2 = math.astensor(np.random.random((5, 6, 7, 8, 10))) + fock1 = ArrayAnsatz(Batch(self.array2578)) + fock2 = ArrayAnsatz(Batch(array2)) + fock_test = fock1.contract(fock2, 2, 2) + assert fock_test.array.shape == (10, 5, 7, 6, 7, 10) + assert np.allclose( + math.reshape(fock_test.array, -1), + math.reshape(np.einsum("bcde, pfgeh -> bpcdfgh", self.array2578, array2), -1), + ) + def test_divide_on_a_scalar(self): - fock1 = ArrayAnsatz(self.array1578, batched=True) + fock1 = ArrayAnsatz(Batch(self.array1578)) fock_test = fock1 / 1.5 assert np.allclose(fock_test.array, self.array1578 / 1.5) @@ -107,20 +119,17 @@ def test_equal(self): aa2 = ArrayAnsatz(array=array) assert aa1 == aa2 - def test_matmul_fock_fock(self): - array2 = math.astensor(np.random.random((5, 6, 7, 8, 10))) - fock1 = ArrayAnsatz(self.array2578, batched=True) - fock2 = ArrayAnsatz(array2, batched=True) - fock_test = fock1[2] @ fock2[2] - assert fock_test.array.shape == (10, 5, 7, 6, 7, 10) - assert np.allclose( - math.reshape(fock_test.array, -1), - math.reshape(np.einsum("bcde, pfgeh -> bpcdfgh", self.array2578, array2), -1), + def test_getitem(self): + batched_ansatz = ArrayAnsatz( + Batch(self.array5578, batch_shape=(5, 5), batch_labels=("a", "b")) ) + sliced_ansatz = batched_ansatz[0] + expected_ansatz = ArrayAnsatz(Batch(self.array5578[0], batch_labels=("b",))) + assert sliced_ansatz.array == expected_ansatz.array def test_mul(self): - fock1 = ArrayAnsatz(self.array1578, batched=True) - fock2 = ArrayAnsatz(self.array5578, batched=True) + fock1 = ArrayAnsatz(Batch(self.array1578)) + fock2 = ArrayAnsatz(Batch(self.array5578)) fock1_mul_fock2 = fock1 * fock2 assert fock1_mul_fock2.array.shape == (5, 5, 7, 8) assert np.allclose( @@ -129,7 +138,7 @@ def test_mul(self): ) def test_multiply_a_scalar(self): - fock1 = ArrayAnsatz(self.array1578, batched=True) + fock1 = ArrayAnsatz(Batch(self.array1578)) fock_test = 1.3 * fock1 assert np.allclose(fock_test.array, 1.3 * self.array1578) @@ -144,7 +153,7 @@ def test_neg(self): def test_reduce(self, batched): shape = (1, 3, 3, 3) if batched else (3, 3, 3) array1 = math.astensor(np.arange(27).reshape(shape)) - fock1 = ArrayAnsatz(array1, batched=batched) + fock1 = ArrayAnsatz(Batch(array1) if batched else array1) fock2 = fock1.reduce(3) assert fock1 == fock2 @@ -175,14 +184,14 @@ def test_reduce_padded(self): def test_reorder(self): array1 = math.astensor(np.arange(8).reshape((1, 2, 2, 2))) - fock1 = ArrayAnsatz(array1, batched=True) + fock1 = ArrayAnsatz(Batch(array1)) fock2 = fock1.reorder(order=(2, 1, 0)) assert np.allclose(fock2.array, np.array([[[[0, 4], [2, 6]], [[1, 5], [3, 7]]]])) assert np.allclose(fock2.array, np.arange(8).reshape((1, 2, 2, 2), order="F")) def test_sub(self): - fock1 = ArrayAnsatz(self.array2578, batched=True) - fock2 = ArrayAnsatz(self.array5578, batched=True) + fock1 = ArrayAnsatz(Batch(self.array2578)) + fock2 = ArrayAnsatz(Batch(self.array5578)) fock1_sub_fock2 = fock1 - fock2 assert fock1_sub_fock2.array.shape == (10, 5, 7, 8) assert np.allclose(fock1_sub_fock2.array[0], self.array2578[0] - self.array5578[0]) @@ -190,26 +199,26 @@ def test_sub(self): assert np.allclose(fock1_sub_fock2.array[9], self.array2578[1] - self.array5578[4]) def test_sum_batch(self): - fock = ArrayAnsatz(self.array2578, batched=True) - fock_collapsed = fock.sum_batch()[0] + fock = ArrayAnsatz(Batch(self.array2578)) + fock_collapsed = fock.sum_batch() assert fock_collapsed.array.shape == (1, 5, 7, 8) assert np.allclose(fock_collapsed.array, np.sum(self.array2578, axis=0)) def test_to_from_dict(self): array1 = math.astensor(np.random.random((2, 5, 5, 1, 7, 4, 1, 7, 3))) - fock1 = ArrayAnsatz(array1, batched=True) + fock1 = ArrayAnsatz(Batch(array1)) assert ArrayAnsatz.from_dict(fock1.to_dict()) == fock1 def test_trace(self): array1 = math.astensor(np.random.random((2, 5, 5, 1, 7, 4, 1, 7, 3))) - fock1 = ArrayAnsatz(array1, batched=True) + fock1 = ArrayAnsatz(Batch(array1)) fock2 = fock1.trace([0, 3], [1, 6]) assert fock2.array.shape == (2, 1, 4, 1, 3) assert np.allclose(fock2.array, np.einsum("bccefghfj -> beghj", array1)) def test_truediv(self): - fock1 = ArrayAnsatz(self.array1578, batched=True) - fock2 = ArrayAnsatz(self.array5578, batched=True) + fock1 = ArrayAnsatz(Batch(self.array1578)) + fock2 = ArrayAnsatz(Batch(self.array5578)) fock1_mul_fock2 = fock1 / fock2 assert fock1_mul_fock2.array.shape == (5, 5, 7, 8) assert np.allclose( @@ -228,7 +237,7 @@ def test_truediv_a_scalar(self): @patch("mrmustard.physics.ansatz.array_ansatz.display") def test_ipython_repr(self, mock_display, shape): """Test the IPython repr function.""" - rep = ArrayAnsatz(np.random.random(shape), batched=True) + rep = ArrayAnsatz(Batch(np.random.random(shape))) rep._ipython_display_() [hbox] = mock_display.call_args.args assert isinstance(hbox, HBox) @@ -251,21 +260,21 @@ def test_ipython_repr(self, mock_display, shape): @patch("mrmustard.physics.ansatz.array_ansatz.display") def test_ipython_repr_expects_batch_1(self, mock_display): """Test the IPython repr function does nothing with real batch.""" - rep = ArrayAnsatz(np.random.random((2, 8)), batched=True) + rep = ArrayAnsatz(Batch(np.random.random((2, 8)))) rep._ipython_display_() mock_display.assert_not_called() @patch("mrmustard.physics.ansatz.array_ansatz.display") def test_ipython_repr_expects_3_dims_or_less(self, mock_display): """Test the IPython repr function does nothing with 4+ dims.""" - rep = ArrayAnsatz(np.random.random((1, 4, 4, 4)), batched=True) + rep = ArrayAnsatz(Batch(np.random.random((1, 4, 4, 4)))) rep._ipython_display_() mock_display.assert_not_called() @patch("mrmustard.widgets.IN_INTERACTIVE_SHELL", True) def test_ipython_repr_interactive(self, capsys): """Test the IPython repr function.""" - rep = ArrayAnsatz(np.random.random((1, 8)), batched=True) + rep = ArrayAnsatz(Batch(np.random.random((1, 8)))) rep._ipython_display_() captured = capsys.readouterr() assert captured.out.rstrip() == repr(rep) diff --git a/tests/test_physics/test_ansatz/test_polyexp_ansatz.py b/tests/test_physics/test_ansatz/test_polyexp_ansatz.py index b432c7374..3d11d4766 100644 --- a/tests/test_physics/test_ansatz/test_polyexp_ansatz.py +++ b/tests/test_physics/test_ansatz/test_polyexp_ansatz.py @@ -25,6 +25,7 @@ from mrmustard import math from mrmustard.physics.ansatz.array_ansatz import ArrayAnsatz +from mrmustard.physics.batches import Batch from mrmustard.physics.ansatz.polyexp_ansatz import PolyExpAnsatz from mrmustard.physics.gaussian_integrals import ( complex_gaussian_integral_1, @@ -103,7 +104,7 @@ def test_add_different_poly_wires(self): def test_add_error(self): bargmann = PolyExpAnsatz(*Abc_triple(3)) - fock = ArrayAnsatz(np.random.random((1, 4, 4, 4)), batched=True) + fock = ArrayAnsatz(Batch(np.random.random((1, 4, 4, 4)))) with pytest.raises(TypeError, match="Cannot add"): bargmann + fock # pylint: disable=pointless-statement @@ -186,6 +187,16 @@ def test_conj(self, triple): assert np.allclose(bargmann.b, math.conj(b)) assert np.allclose(bargmann.c, math.conj(c)) + def test_contract_barg_barg(self): + triple1 = Abc_triple(3) + triple2 = Abc_triple(3) + + res1 = PolyExpAnsatz(*triple1).contract(PolyExpAnsatz(*triple2)) + exp1 = complex_gaussian_integral_2(triple1, triple2, [], []) + assert np.allclose(res1.A, exp1[0]) + assert np.allclose(res1.b, exp1[1]) + assert np.allclose(res1.c, exp1[2]) + def test_decompose_ansatz(self): A, b, _ = Abc_triple(4) c = np.random.uniform(-10, 10, size=(1, 3, 3, 3)) @@ -193,13 +204,13 @@ def test_decompose_ansatz(self): decomp_ansatz = ansatz.decompose_ansatz() z = np.random.uniform(-10, 10, size=(1, 1)) - assert np.allclose(ansatz(z), decomp_ansatz(z)) - assert np.allclose(decomp_ansatz.A.shape, (1, 2, 2)) + assert math.allclose(ansatz(z), decomp_ansatz(z)) + assert math.allclose(decomp_ansatz.A.shape, (1, 2, 2)) c2 = np.random.uniform(-10, 10, size=(1, 4)) ansatz2 = PolyExpAnsatz(A, b, c2) decomp_ansatz2 = ansatz2.decompose_ansatz() - assert np.allclose(decomp_ansatz2.A, ansatz2.A) + assert math.allclose(decomp_ansatz2.A, ansatz2.A) def test_decompose_ansatz_batch(self): """ @@ -240,9 +251,9 @@ def test_div(self, n): bargmann2 = PolyExpAnsatz(*triple2) bargmann_div = bargmann1 / bargmann2 - assert np.allclose(bargmann_div.A, bargmann1.A - bargmann2.A) - assert np.allclose(bargmann_div.b, bargmann1.b - bargmann2.b) - assert np.allclose(bargmann_div.c, bargmann1.c / bargmann2.c) + assert np.allclose(bargmann_div.A, bargmann1.A.data - bargmann2.A.data) + assert np.allclose(bargmann_div.b, bargmann1.b.data - bargmann2.b.data) + assert np.allclose(bargmann_div.c, bargmann1.c.data / bargmann2.c.data) @pytest.mark.parametrize("scalar", [0.5, 1.2]) @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) @@ -265,6 +276,19 @@ def test_eq(self): assert ansatz != ansatz2 assert ansatz2 != ansatz + def test_getitem(self): + A, b, c = Abc_triple(5) + batched_ansatz = PolyExpAnsatz( + Batch(math.astensor([A, A, A])), + Batch(math.astensor([b, b, b])), + Batch(math.astensor([c, c, c])), + ) + sliced_ansatz = batched_ansatz[0] + expected_ansatz = PolyExpAnsatz(A, b, c) + assert math.allclose(sliced_ansatz.A, expected_ansatz.A) + assert math.allclose(sliced_ansatz.b, expected_ansatz.b) + assert math.allclose(sliced_ansatz.c, expected_ansatz.c) + def test_inconsistent_poly_shapes(self): A1 = np.random.random((1, 2, 2)) A2 = np.random.random((1, 3, 3)) @@ -329,16 +353,6 @@ def test_ipython_repr_interactive(self, capsys): captured = capsys.readouterr() assert captured.out.rstrip() == repr(rep) - def test_matmul_barg_barg(self): - triple1 = Abc_triple(3) - triple2 = Abc_triple(3) - - res1 = PolyExpAnsatz(*triple1) @ PolyExpAnsatz(*triple2) - exp1 = complex_gaussian_integral_2(triple1, triple2, [], []) - assert np.allclose(res1.A, exp1[0]) - assert np.allclose(res1.b, exp1[1]) - assert np.allclose(res1.c, exp1[2]) - @pytest.mark.parametrize("n", [1, 2, 3]) def test_mul(self, n): triple1 = Abc_triple(n) @@ -348,9 +362,9 @@ def test_mul(self, n): bargmann2 = PolyExpAnsatz(*triple2) bargmann_mul = bargmann1 * bargmann2 - assert np.allclose(bargmann_mul.A, bargmann1.A + bargmann2.A) - assert np.allclose(bargmann_mul.b, bargmann1.b + bargmann2.b) - assert np.allclose(bargmann_mul.c, bargmann1.c * bargmann2.c) + assert np.allclose(bargmann_mul.A, bargmann1.A.data + bargmann2.A.data) + assert np.allclose(bargmann_mul.b, bargmann1.b.data + bargmann2.b.data) + assert np.allclose(bargmann_mul.c, bargmann1.c.data * bargmann2.c.data) @pytest.mark.parametrize("scalar", [0.5, 1.2]) @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) @@ -414,15 +428,15 @@ def test_simplify(self): ansatz = ansatz + ansatz - assert np.allclose(ansatz.A[0], ansatz.A[1]) - assert np.allclose(ansatz.A[0], A) - assert np.allclose(ansatz.b[0], ansatz.b[1]) - assert np.allclose(ansatz.b[0], b) + assert math.allclose(ansatz.A[0], ansatz.A[1]) + assert math.allclose(ansatz.A[0], A) + assert math.allclose(ansatz.b[0], ansatz.b[1]) + assert math.allclose(ansatz.b[0], b) ansatz.simplify() assert len(ansatz.A) == 1 assert len(ansatz.b) == 1 - assert ansatz.c == 2 * c + assert math.allclose(ansatz.c, 2 * c) def test_simplify_v2(self): A, b, c = Abc_triple(5) diff --git a/tests/test_physics/test_batches.py b/tests/test_physics/test_batches.py new file mode 100644 index 000000000..ffca74594 --- /dev/null +++ b/tests/test_physics/test_batches.py @@ -0,0 +1,88 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the Batch class.""" + +# pylint: disable=missing-function-docstring + +import pytest +import numpy as np + +from mrmustard import math +from mrmustard.physics.batches import Batch + + +class TestBatch: + r""" + Tests for the Batch class. + """ + + array5688 = np.random.random((5, 6, 8, 8)) + + def test_init(self): + batch_default = Batch(data=self.array5688) + assert math.allclose(batch_default.data, self.array5688) + assert len(batch_default.batch_labels) == 1 + assert batch_default.batch_shape == (5,) + assert batch_default.core_shape == (6, 8, 8) + assert batch_default.shape == self.array5688.shape + + batch = Batch(data=self.array5688, batch_shape=(5, 6), batch_labels=("a", "b")) + assert math.allclose(batch.data, self.array5688) + assert batch.batch_labels == ("a", "b") + assert batch.batch_shape == (5, 6) + assert batch.core_shape == (8, 8) + assert batch.shape == self.array5688.shape + + assert math.allclose(batch_default, batch) + assert batch_default != batch + + with pytest.raises(ValueError, match="batch shape"): + Batch(self.array5688, batch_shape=(6, 6)) # pylint: disable=pointless-statement + + def test_getitem(self): + batch = Batch(data=self.array5688, batch_shape=(5, 6), batch_labels=("a", "b")) + + batch_slice0 = batch[0] + assert isinstance(batch_slice0, Batch) + assert math.allclose(batch_slice0.data, self.array5688[0]) + assert batch_slice0.batch_labels == ("b",) + assert batch_slice0.batch_shape == (6,) + assert batch_slice0.core_shape == (8, 8) + assert batch_slice0.shape == (6, 8, 8) + + batch_slice1 = batch[:, 0] + assert isinstance(batch_slice1, Batch) + assert math.allclose(batch_slice1.data, self.array5688[:, 0]) + assert batch_slice1.batch_labels == ("a",) + assert batch_slice1.batch_shape == (5,) + assert batch_slice1.core_shape == (8, 8) + assert batch_slice1.shape == (5, 8, 8) + + batch_slice2 = batch[0, 0] + assert not isinstance(batch_slice2, Batch) + assert math.allclose(batch_slice2, self.array5688[0, 0]) + + with pytest.raises(IndexError, match="indices"): + batch[:, :, 0] # pylint: disable=pointless-statement + + def test_ufunc(self): + batch = Batch(data=self.array5688, batch_shape=(5, 6), batch_labels=("a", "b")) + # __call__ + assert math.allclose(math.exp(batch), math.exp(self.array5688)) + # reduce + assert math.allclose(math.sum(batch, axis=(1,)), math.sum(self.array5688, axis=(1,))) + + with pytest.raises(ValueError, match="out of bounds"): + math.sum(batch, axis=(2,)) # pylint: disable=pointless-statement diff --git a/tests/test_physics/test_representations.py b/tests/test_physics/test_representations.py index 52d047bad..40ba69aa6 100644 --- a/tests/test_physics/test_representations.py +++ b/tests/test_physics/test_representations.py @@ -101,7 +101,9 @@ def test_matmul_btoq(self, d_gate_rep, btoq_rep): def test_to_bargmann(self, d_gate_rep): d_fock = d_gate_rep.to_fock(shape=(4, 6)) d_barg = d_fock.to_bargmann() + assert d_fock.ansatz._original_abc_data == d_gate_rep.ansatz.triple + assert math.allclose(d_barg.bargmann_triple()[0], d_gate_rep.bargmann_triple()[0]) assert d_barg == d_gate_rep assert all((k[0] == RepEnum.BARGMANN for k in d_barg._idx_reps.values())) diff --git a/tests/test_physics/test_triples.py b/tests/test_physics/test_triples.py index c84796231..0026ecba1 100644 --- a/tests/test_physics/test_triples.py +++ b/tests/test_physics/test_triples.py @@ -335,7 +335,7 @@ def test_displacement_gate_s_parametrized_Abc(self): def test_attenuator_kraus_Abc(self, eta): B = PolyExpAnsatz(*triples.attenuator_kraus_Abc(eta)) Att = PolyExpAnsatz(*triples.attenuator_Abc(eta)) - assert B[2] @ B[2] == Att + assert B.contract(B, 2, 2) == Att def test_gaussian_random_noise_Abc(self):