diff --git a/src/depiction/tools/simulate/__init__.py b/src/depiction/tools/simulate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/depiction/tools/simulate/generate_label_image.py b/src/depiction/tools/simulate/generate_label_image.py new file mode 100644 index 0000000..997cfaa --- /dev/null +++ b/src/depiction/tools/simulate/generate_label_image.py @@ -0,0 +1,75 @@ +from typing import Sequence + +import numpy as np +from xarray import DataArray + +from depiction.image.multi_channel_image import MultiChannelImage + + +class GenerateLabelImage: + """Generates a label image (i.e. multi-channel image without noise) that can be used to generate a MSI + dataset later.""" + + def __init__(self, image_height: int, image_width: int, n_labels: int, seed: int = 0) -> None: + self._layers = [] + self._image_height = image_height + self._image_width = image_width + self._n_labels = n_labels + self._rng = np.random.default_rng(seed) + + @property + def shape(self) -> tuple[int, int, int]: + """Returns the shape of the label image (height, width, n_labels).""" + return self._image_height, self._image_width, self._n_labels + + def sample_circles( + self, channel_indices: Sequence[int] | None = None, radius_mean: float = 15, radius_std: float = 5 + ) -> list[dict[str, float | int]]: + if channel_indices is None: + channel_indices = range(self._n_labels) + circles = [] + for i_channel in channel_indices: + center_h = self._rng.uniform(0, self._image_height) + center_w = self._rng.uniform(0, self._image_width) + radius = self._rng.normal(radius_mean, radius_std) + circles.append({"center_h": center_h, "center_w": center_w, "radius": radius, "i_channel": i_channel}) + return circles + + def add_circles( + self, + circles: list[dict[str, float]], + ) -> None: + label_image = np.zeros(self.shape) + for circle in circles: + center_h, center_w = circle["center_h"], circle["center_w"] + radius = circle["radius"] + i_label = circle["i_channel"] + for h in range(self._image_height): + for w in range(self._image_width): + distance = np.sqrt((h - center_h) ** 2 + (w - center_w) ** 2) + if distance < radius: + label_image[h, w, i_label] = 1 + + self._layers.append(label_image) + + def add_stripe_pattern(self, i_channel: int, bandwidth: float, rotation: float = 45.0, phase: float = 0.0) -> None: + def f(x, y): + return np.sin(y / bandwidth * 2 * np.pi + np.radians(phase)) + + data = np.zeros((self._image_height, self._image_width)) + phi = np.radians(rotation) + rot = np.array([[np.cos(phi), -np.sin(phi)], [np.sin(phi), np.cos(phi)]]) + for i in range(self._image_height): + for j in range(self._image_width): + i_rot, j_rot = np.dot(rot, [i, j]) + data[i, j] = (f(i_rot, j_rot) + 1) / 2 + + layer = np.zeros(self.shape) + layer[:, :, i_channel] = data + self._layers.append(layer) + + def render(self) -> MultiChannelImage: + blended = np.sum(self._layers, axis=0) + data = DataArray(blended, dims=("y", "x", "c"), coords={"c": [f"synthetic_{i}" for i in range(self._n_labels)]}) + data.attrs["bg_value"] = 0.0 + return MultiChannelImage(data) diff --git a/src/depiction/misc/experimental/synthetic_maldi_ihc_data.py b/src/depiction/tools/simulate/synthetic_msi_data_generator.py similarity index 53% rename from src/depiction/misc/experimental/synthetic_maldi_ihc_data.py rename to src/depiction/tools/simulate/synthetic_msi_data_generator.py index 1f7c56b..2b1addb 100644 --- a/src/depiction/misc/experimental/synthetic_maldi_ihc_data.py +++ b/src/depiction/tools/simulate/synthetic_msi_data_generator.py @@ -4,52 +4,18 @@ import scipy from numpy.typing import NDArray from tqdm import tqdm -from xarray import DataArray from depiction.estimate_ppm_error import EstimatePPMError from depiction.image.multi_channel_image import MultiChannelImage from depiction.persistence import ImzmlWriteFile -class SyntheticMaldiIhcData: - """Helper that creates synthetic MALDI IHC data.""" +class SyntheticMSIDataGenerator: + """Helper that creates synthetic MSI data.""" def __init__(self, seed: int = 0) -> None: self.rng = np.random.default_rng(seed) - def generate_label_image_circles( - self, - n_labels: int, - image_height: int, - image_width: int, - radius_mean: float = 15, - radius_std: float = 5, - ) -> MultiChannelImage: - """Generates a label image with a circle for each specified label. - Will generate a full rectangular image. - :param n_labels: The number of labels to generate. - :param image_height: The height of the image. - :param image_width: The width of the image. - :param radius_mean: The mean radius of the circles (drawn from a normal distribution). - :param radius_std: The standard deviation of the radius of the circles (drawn from a normal distribution). - """ - label_image = np.zeros((image_height, image_width, n_labels)) - - for i_label in range(n_labels): - center_h = self.rng.uniform(0, image_height) - center_w = self.rng.uniform(0, image_width) - radius = self.rng.normal(radius_mean, radius_std) - - for h in range(image_height): - for w in range(image_width): - distance = np.sqrt((h - center_h) ** 2 + (w - center_w) ** 2) - if distance < radius: - label_image[h, w, i_label] = 1 - - data = DataArray(label_image, dims=("y", "x", "c"), coords={"c": [f"synthetic_{i}" for i in range(n_labels)]}) - data["bg_value"] = 0.0 - return MultiChannelImage(data) - def generate_imzml_for_labels( self, write_file: ImzmlWriteFile, @@ -113,29 +79,3 @@ def next_peak(peak_mz: float) -> float: def get_mz_arr(self, min_mass: float, max_mass: float, bin_width_ppm: float) -> NDArray[float]: return EstimatePPMError.ppm_to_mz_values(bin_width_ppm, mz_min=min_mass, mz_max=max_mass) - - @staticmethod - def generate_diagonal_stripe_pattern( - image_height: int, image_width: int, bandwidth: float, rotation: float = 45.0, phase: float = 0.0 - ) -> NDArray[float]: - """Generates a diagonal stripe pattern. - Values are in the range [0, 1]. - :param image_height: The height of the image. - :param image_width: The width of the image. - :param bandwidth: The bandwidth of the sine wave, i.e. after this many pixels (unrotated) the pattern repeats. - :param rotation: The rotation of the pattern in degrees (by default 45 degrees). - :param phase: The phase of the sine wave, can be used to shift the pattern (periodicity of 360 degrees). - """ - - def f(x, y): - return np.sin(y / bandwidth * 2 * np.pi + np.radians(phase)) - - data = np.zeros((image_height, image_width)) - phi = np.radians(rotation) - rot = np.array([[np.cos(phi), -np.sin(phi)], [np.sin(phi), np.cos(phi)]]) - for i in range(image_height): - for j in range(image_width): - i_rot, j_rot = np.dot(rot, [i, j]) - data[i, j] = (f(i_rot, j_rot) + 1) / 2 - - return data diff --git a/tests/unit/tools/simulate/__init__.py b/tests/unit/tools/simulate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/tools/simulate/test_generate_label_image.py b/tests/unit/tools/simulate/test_generate_label_image.py new file mode 100644 index 0000000..3fda0dd --- /dev/null +++ b/tests/unit/tools/simulate/test_generate_label_image.py @@ -0,0 +1,67 @@ +import numpy as np +import pytest + +from depiction.tools.simulate.generate_label_image import GenerateLabelImage + + +@pytest.fixture +def generate() -> GenerateLabelImage: + return GenerateLabelImage(100, 200, 3) + + +def test_shape(generate) -> None: + assert generate.shape == (100, 200, 3) + + +def test_sample_circles_when_indices(generate) -> None: + channel_indices = [1, 2] + circles = generate.sample_circles(channel_indices) + assert len(circles) == 2 + assert circles[0]["i_channel"] == 1 + assert 0 <= circles[0]["center_h"] <= 100 + assert 0 <= circles[0]["center_w"] <= 200 + assert isinstance(circles[0]["radius"], float) + assert circles[1]["i_channel"] == 2 + + +def test_sample_circles_when_no_indices(generate) -> None: + circles = generate.sample_circles() + assert len(circles) == 3 + assert [c["i_channel"] for c in circles] == [0, 1, 2] + + +def test_add_circles(generate) -> None: + generate.add_circles( + [ + {"center_h": 50, "center_w": 100, "radius": 3, "i_channel": 0}, + {"center_h": 70, "center_w": 70, "radius": 3, "i_channel": 1}, + ] + ) + assert len(generate._layers) == 1 + assert generate._layers[0].shape == (100, 200, 3) + # check center of circle, and then 4 points away from the center + layer = generate._layers[0] + assert layer[50, 100, 0] == 1 + assert layer[70, 70, 1] == 1 + np.testing.assert_equal(layer[:, :, 2], 0) + assert layer[53, 103, 0] == 0 + assert layer[73, 73, 1] == 0 + + +def test_render(generate) -> None: + generate._layers = [ + np.array( + [ + [[1, 0, 0], [0, 1, 0]], + [[0, 0, 1], [1, 1, 0]], + ] + ) + ] + image = generate.render() + assert image.n_channels == 3 + assert (2, 2) == image.dimensions + assert ["synthetic_0", "synthetic_1", "synthetic_2"] == image.channel_names + + +if __name__ == "__main__": + pytest.main()