Skip to content

Commit

Permalink
start refactoring synthetic data module
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Jun 17, 2024
1 parent 270d922 commit 5263c5b
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 62 deletions.
Empty file.
75 changes: 75 additions & 0 deletions src/depiction/tools/simulate/generate_label_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import Sequence

import numpy as np
from xarray import DataArray

from depiction.image.multi_channel_image import MultiChannelImage


class GenerateLabelImage:
"""Generates a label image (i.e. multi-channel image without noise) that can be used to generate a MSI
dataset later."""

def __init__(self, image_height: int, image_width: int, n_labels: int, seed: int = 0) -> None:
self._layers = []
self._image_height = image_height
self._image_width = image_width
self._n_labels = n_labels
self._rng = np.random.default_rng(seed)

@property
def shape(self) -> tuple[int, int, int]:
"""Returns the shape of the label image (height, width, n_labels)."""
return self._image_height, self._image_width, self._n_labels

def sample_circles(
self, channel_indices: Sequence[int] | None = None, radius_mean: float = 15, radius_std: float = 5
) -> list[dict[str, float | int]]:
if channel_indices is None:
channel_indices = range(self._n_labels)
circles = []
for i_channel in channel_indices:
center_h = self._rng.uniform(0, self._image_height)
center_w = self._rng.uniform(0, self._image_width)
radius = self._rng.normal(radius_mean, radius_std)
circles.append({"center_h": center_h, "center_w": center_w, "radius": radius, "i_channel": i_channel})
return circles

def add_circles(
self,
circles: list[dict[str, float]],
) -> None:
label_image = np.zeros(self.shape)
for circle in circles:
center_h, center_w = circle["center_h"], circle["center_w"]
radius = circle["radius"]
i_label = circle["i_channel"]
for h in range(self._image_height):
for w in range(self._image_width):
distance = np.sqrt((h - center_h) ** 2 + (w - center_w) ** 2)
if distance < radius:
label_image[h, w, i_label] = 1

self._layers.append(label_image)

def add_stripe_pattern(self, i_channel: int, bandwidth: float, rotation: float = 45.0, phase: float = 0.0) -> None:
def f(x, y):
return np.sin(y / bandwidth * 2 * np.pi + np.radians(phase))

data = np.zeros((self._image_height, self._image_width))
phi = np.radians(rotation)
rot = np.array([[np.cos(phi), -np.sin(phi)], [np.sin(phi), np.cos(phi)]])
for i in range(self._image_height):
for j in range(self._image_width):
i_rot, j_rot = np.dot(rot, [i, j])
data[i, j] = (f(i_rot, j_rot) + 1) / 2

layer = np.zeros(self.shape)
layer[:, :, i_channel] = data
self._layers.append(layer)

def render(self) -> MultiChannelImage:
blended = np.sum(self._layers, axis=0)
data = DataArray(blended, dims=("y", "x", "c"), coords={"c": [f"synthetic_{i}" for i in range(self._n_labels)]})
data.attrs["bg_value"] = 0.0
return MultiChannelImage(data)
Original file line number Diff line number Diff line change
Expand Up @@ -4,52 +4,18 @@
import scipy
from numpy.typing import NDArray
from tqdm import tqdm
from xarray import DataArray

from depiction.estimate_ppm_error import EstimatePPMError
from depiction.image.multi_channel_image import MultiChannelImage
from depiction.persistence import ImzmlWriteFile


class SyntheticMaldiIhcData:
"""Helper that creates synthetic MALDI IHC data."""
class SyntheticMSIDataGenerator:
"""Helper that creates synthetic MSI data."""

def __init__(self, seed: int = 0) -> None:
self.rng = np.random.default_rng(seed)

def generate_label_image_circles(
self,
n_labels: int,
image_height: int,
image_width: int,
radius_mean: float = 15,
radius_std: float = 5,
) -> MultiChannelImage:
"""Generates a label image with a circle for each specified label.
Will generate a full rectangular image.
:param n_labels: The number of labels to generate.
:param image_height: The height of the image.
:param image_width: The width of the image.
:param radius_mean: The mean radius of the circles (drawn from a normal distribution).
:param radius_std: The standard deviation of the radius of the circles (drawn from a normal distribution).
"""
label_image = np.zeros((image_height, image_width, n_labels))

for i_label in range(n_labels):
center_h = self.rng.uniform(0, image_height)
center_w = self.rng.uniform(0, image_width)
radius = self.rng.normal(radius_mean, radius_std)

for h in range(image_height):
for w in range(image_width):
distance = np.sqrt((h - center_h) ** 2 + (w - center_w) ** 2)
if distance < radius:
label_image[h, w, i_label] = 1

data = DataArray(label_image, dims=("y", "x", "c"), coords={"c": [f"synthetic_{i}" for i in range(n_labels)]})
data["bg_value"] = 0.0
return MultiChannelImage(data)

def generate_imzml_for_labels(
self,
write_file: ImzmlWriteFile,
Expand Down Expand Up @@ -113,29 +79,3 @@ def next_peak(peak_mz: float) -> float:

def get_mz_arr(self, min_mass: float, max_mass: float, bin_width_ppm: float) -> NDArray[float]:
return EstimatePPMError.ppm_to_mz_values(bin_width_ppm, mz_min=min_mass, mz_max=max_mass)

@staticmethod
def generate_diagonal_stripe_pattern(
image_height: int, image_width: int, bandwidth: float, rotation: float = 45.0, phase: float = 0.0
) -> NDArray[float]:
"""Generates a diagonal stripe pattern.
Values are in the range [0, 1].
:param image_height: The height of the image.
:param image_width: The width of the image.
:param bandwidth: The bandwidth of the sine wave, i.e. after this many pixels (unrotated) the pattern repeats.
:param rotation: The rotation of the pattern in degrees (by default 45 degrees).
:param phase: The phase of the sine wave, can be used to shift the pattern (periodicity of 360 degrees).
"""

def f(x, y):
return np.sin(y / bandwidth * 2 * np.pi + np.radians(phase))

data = np.zeros((image_height, image_width))
phi = np.radians(rotation)
rot = np.array([[np.cos(phi), -np.sin(phi)], [np.sin(phi), np.cos(phi)]])
for i in range(image_height):
for j in range(image_width):
i_rot, j_rot = np.dot(rot, [i, j])
data[i, j] = (f(i_rot, j_rot) + 1) / 2

return data
Empty file.
67 changes: 67 additions & 0 deletions tests/unit/tools/simulate/test_generate_label_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import numpy as np
import pytest

from depiction.tools.simulate.generate_label_image import GenerateLabelImage


@pytest.fixture
def generate() -> GenerateLabelImage:
return GenerateLabelImage(100, 200, 3)


def test_shape(generate) -> None:
assert generate.shape == (100, 200, 3)


def test_sample_circles_when_indices(generate) -> None:
channel_indices = [1, 2]
circles = generate.sample_circles(channel_indices)
assert len(circles) == 2
assert circles[0]["i_channel"] == 1
assert 0 <= circles[0]["center_h"] <= 100
assert 0 <= circles[0]["center_w"] <= 200
assert isinstance(circles[0]["radius"], float)
assert circles[1]["i_channel"] == 2


def test_sample_circles_when_no_indices(generate) -> None:
circles = generate.sample_circles()
assert len(circles) == 3
assert [c["i_channel"] for c in circles] == [0, 1, 2]


def test_add_circles(generate) -> None:
generate.add_circles(
[
{"center_h": 50, "center_w": 100, "radius": 3, "i_channel": 0},
{"center_h": 70, "center_w": 70, "radius": 3, "i_channel": 1},
]
)
assert len(generate._layers) == 1
assert generate._layers[0].shape == (100, 200, 3)
# check center of circle, and then 4 points away from the center
layer = generate._layers[0]
assert layer[50, 100, 0] == 1
assert layer[70, 70, 1] == 1
np.testing.assert_equal(layer[:, :, 2], 0)
assert layer[53, 103, 0] == 0
assert layer[73, 73, 1] == 0


def test_render(generate) -> None:
generate._layers = [
np.array(
[
[[1, 0, 0], [0, 1, 0]],
[[0, 0, 1], [1, 1, 0]],
]
)
]
image = generate.render()
assert image.n_channels == 3
assert (2, 2) == image.dimensions
assert ["synthetic_0", "synthetic_1", "synthetic_2"] == image.channel_names


if __name__ == "__main__":
pytest.main()

0 comments on commit 5263c5b

Please sign in to comment.