-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
start refactoring synthetic data module
- Loading branch information
1 parent
270d922
commit 5263c5b
Showing
5 changed files
with
144 additions
and
62 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from typing import Sequence | ||
|
||
import numpy as np | ||
from xarray import DataArray | ||
|
||
from depiction.image.multi_channel_image import MultiChannelImage | ||
|
||
|
||
class GenerateLabelImage: | ||
"""Generates a label image (i.e. multi-channel image without noise) that can be used to generate a MSI | ||
dataset later.""" | ||
|
||
def __init__(self, image_height: int, image_width: int, n_labels: int, seed: int = 0) -> None: | ||
self._layers = [] | ||
self._image_height = image_height | ||
self._image_width = image_width | ||
self._n_labels = n_labels | ||
self._rng = np.random.default_rng(seed) | ||
|
||
@property | ||
def shape(self) -> tuple[int, int, int]: | ||
"""Returns the shape of the label image (height, width, n_labels).""" | ||
return self._image_height, self._image_width, self._n_labels | ||
|
||
def sample_circles( | ||
self, channel_indices: Sequence[int] | None = None, radius_mean: float = 15, radius_std: float = 5 | ||
) -> list[dict[str, float | int]]: | ||
if channel_indices is None: | ||
channel_indices = range(self._n_labels) | ||
circles = [] | ||
for i_channel in channel_indices: | ||
center_h = self._rng.uniform(0, self._image_height) | ||
center_w = self._rng.uniform(0, self._image_width) | ||
radius = self._rng.normal(radius_mean, radius_std) | ||
circles.append({"center_h": center_h, "center_w": center_w, "radius": radius, "i_channel": i_channel}) | ||
return circles | ||
|
||
def add_circles( | ||
self, | ||
circles: list[dict[str, float]], | ||
) -> None: | ||
label_image = np.zeros(self.shape) | ||
for circle in circles: | ||
center_h, center_w = circle["center_h"], circle["center_w"] | ||
radius = circle["radius"] | ||
i_label = circle["i_channel"] | ||
for h in range(self._image_height): | ||
for w in range(self._image_width): | ||
distance = np.sqrt((h - center_h) ** 2 + (w - center_w) ** 2) | ||
if distance < radius: | ||
label_image[h, w, i_label] = 1 | ||
|
||
self._layers.append(label_image) | ||
|
||
def add_stripe_pattern(self, i_channel: int, bandwidth: float, rotation: float = 45.0, phase: float = 0.0) -> None: | ||
def f(x, y): | ||
return np.sin(y / bandwidth * 2 * np.pi + np.radians(phase)) | ||
|
||
data = np.zeros((self._image_height, self._image_width)) | ||
phi = np.radians(rotation) | ||
rot = np.array([[np.cos(phi), -np.sin(phi)], [np.sin(phi), np.cos(phi)]]) | ||
for i in range(self._image_height): | ||
for j in range(self._image_width): | ||
i_rot, j_rot = np.dot(rot, [i, j]) | ||
data[i, j] = (f(i_rot, j_rot) + 1) / 2 | ||
|
||
layer = np.zeros(self.shape) | ||
layer[:, :, i_channel] = data | ||
self._layers.append(layer) | ||
|
||
def render(self) -> MultiChannelImage: | ||
blended = np.sum(self._layers, axis=0) | ||
data = DataArray(blended, dims=("y", "x", "c"), coords={"c": [f"synthetic_{i}" for i in range(self._n_labels)]}) | ||
data.attrs["bg_value"] = 0.0 | ||
return MultiChannelImage(data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from depiction.tools.simulate.generate_label_image import GenerateLabelImage | ||
|
||
|
||
@pytest.fixture | ||
def generate() -> GenerateLabelImage: | ||
return GenerateLabelImage(100, 200, 3) | ||
|
||
|
||
def test_shape(generate) -> None: | ||
assert generate.shape == (100, 200, 3) | ||
|
||
|
||
def test_sample_circles_when_indices(generate) -> None: | ||
channel_indices = [1, 2] | ||
circles = generate.sample_circles(channel_indices) | ||
assert len(circles) == 2 | ||
assert circles[0]["i_channel"] == 1 | ||
assert 0 <= circles[0]["center_h"] <= 100 | ||
assert 0 <= circles[0]["center_w"] <= 200 | ||
assert isinstance(circles[0]["radius"], float) | ||
assert circles[1]["i_channel"] == 2 | ||
|
||
|
||
def test_sample_circles_when_no_indices(generate) -> None: | ||
circles = generate.sample_circles() | ||
assert len(circles) == 3 | ||
assert [c["i_channel"] for c in circles] == [0, 1, 2] | ||
|
||
|
||
def test_add_circles(generate) -> None: | ||
generate.add_circles( | ||
[ | ||
{"center_h": 50, "center_w": 100, "radius": 3, "i_channel": 0}, | ||
{"center_h": 70, "center_w": 70, "radius": 3, "i_channel": 1}, | ||
] | ||
) | ||
assert len(generate._layers) == 1 | ||
assert generate._layers[0].shape == (100, 200, 3) | ||
# check center of circle, and then 4 points away from the center | ||
layer = generate._layers[0] | ||
assert layer[50, 100, 0] == 1 | ||
assert layer[70, 70, 1] == 1 | ||
np.testing.assert_equal(layer[:, :, 2], 0) | ||
assert layer[53, 103, 0] == 0 | ||
assert layer[73, 73, 1] == 0 | ||
|
||
|
||
def test_render(generate) -> None: | ||
generate._layers = [ | ||
np.array( | ||
[ | ||
[[1, 0, 0], [0, 1, 0]], | ||
[[0, 0, 1], [1, 1, 0]], | ||
] | ||
) | ||
] | ||
image = generate.render() | ||
assert image.n_channels == 3 | ||
assert (2, 2) == image.dimensions | ||
assert ["synthetic_0", "synthetic_1", "synthetic_2"] == image.channel_names | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main() |