Skip to content

Commit

Permalink
convolution added to numpy backend (#517)
Browse files Browse the repository at this point in the history
Co-authored-by: Anthony <[email protected]>
  • Loading branch information
ziofil and apchytr authored Nov 27, 2024
1 parent f3c226d commit 4c603d0
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 3 deletions.
3 changes: 1 addition & 2 deletions mrmustard/math/backend_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,8 +1527,7 @@ def all_diagonals(self, rho: Tensor, real: bool) -> Tensor:

def poisson(self, max_k: int, rate: Tensor) -> Tensor:
"""Poisson distribution up to ``max_k``."""
k = self.arange(max_k)
rate = self.cast(rate, k.dtype)
k = self.arange(max_k, dtype=rate.dtype)
return self.exp(k * self.log(rate + 1e-9) - rate - self.lgamma(k + 1.0))

def binomial_conditional_prob(self, success_prob: Tensor, dim_out: int, dim_in: int):
Expand Down
56 changes: 55 additions & 1 deletion mrmustard/math/backend_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""This module contains the numpy backend."""

# pylint: disable = missing-function-docstring, missing-class-docstring, fixme
# pylint: disable = missing-function-docstring, missing-class-docstring, fixme, too-many-positional-arguments

from __future__ import annotations

Expand All @@ -23,6 +23,7 @@

import numpy as np
import scipy as sp
from scipy.signal import convolve2d as scipy_convolve2d
from scipy.linalg import expm as scipy_expm
from scipy.linalg import sqrtm as scipy_sqrtm
from scipy.special import xlogy as scipy_xlogy
Expand Down Expand Up @@ -136,6 +137,59 @@ def concat(self, values: list[np.ndarray], axis: int) -> np.ndarray:
def conj(self, array: np.ndarray) -> np.ndarray:
return np.conj(array)

def convolution(
self,
array: np.ndarray,
filters: np.ndarray,
padding: str = "VALID",
data_format: str | None = None, # pylint: disable=unused-argument
) -> np.ndarray:
"""Performs a 2D convolution operation similar to tf.nn.convolution.
Args:
array: Input array of shape (batch, height, width, channels)
filters: Filter kernel of shape (kernel_height, kernel_width, in_channels, out_channels)
padding: String indicating the padding type ('VALID' or 'SAME')
data_format: Unused, kept for API compatibility
Returns:
np.ndarray: Result of the convolution operation with shape (batch, new_height, new_width, out_channels)
"""
# Extract shapes
batch, _, _, _ = array.shape
kernel_h, kernel_w, _, out_channels = filters.shape

# Reshape filter to 2D for convolution
filter_2d = filters[:, :, 0, 0]

# For SAME padding, calculate padding sizes
if padding == "SAME":
pad_h = (kernel_h - 1) // 2
pad_w = (kernel_w - 1) // 2
array = np.pad(
array[:, :, :, 0], ((0, 0), (pad_h, pad_h), (pad_w, pad_w)), mode="constant"
)
else:
array = array[:, :, :, 0]

# Calculate output dimensions
out_height = array.shape[1] - kernel_h + 1
out_width = array.shape[2] - kernel_w + 1

# Initialize output array
output = np.zeros((batch, out_height, out_width, out_channels))

# Perform convolution for each batch
for b in range(batch):
# Convolve using scipy's convolve2d which is more efficient than np.convolve for 2D
output[b, :, :, 0] = scipy_convolve2d(
array[b],
np.flip(np.flip(filter_2d, 0), 1), # Flip kernel for proper convolution
mode="valid",
)

return output

def cos(self, array: np.ndarray) -> np.ndarray:
return np.cos(array)

Expand Down

0 comments on commit 4c603d0

Please sign in to comment.