optimized tissue segmentation
Implemented a more robust tissue segmentation in which adipose tissue is mostly ignored for the stitch. Adipose tissue tends to be highly deformable and should not influence the stitch too much; it is therefore filtered out with this commit.
dnschouten committed May 26, 2023
1 parent 3a4882b commit 4274a01
Showing 6 changed files with 421 additions and 129 deletions.
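For context, a minimal sketch of the masking idea described in the commit message, not code from this repository: adipose tissue appears pale and has very low saturation in HSV space, so an Otsu threshold on the median-blurred saturation channel tends to keep stroma and epithelium while dropping fat. Multiplying that mask with an existing tissue segmentation then suppresses adipose regions. The function name and both inputs (`image` as an RGB uint8 array, `tissue_mask` as a binary uint8 array of the same size) are assumptions for illustration.

import cv2
import numpy as np


def adipose_filtered_mask(image: np.ndarray, tissue_mask: np.ndarray) -> np.ndarray:
    """Combine a saturation-based Otsu mask with an existing tissue mask (illustrative sketch)."""
    # Saturation separates pale, low-saturation fat from other tissue.
    saturation = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)[:, :, 1]
    saturation = cv2.medianBlur(saturation, 7)

    # Otsu threshold -> binary {0, 1} mask of saturated (non-adipose) pixels.
    _, otsu_mask = cv2.threshold(saturation, 0, 255, cv2.THRESH_OTSU + cv2.THRESH_BINARY)
    otsu_mask = (otsu_mask / 255).astype("uint8")

    # Light morphological closing to remove speckle before combining.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (8, 8))
    otsu_mask = cv2.morphologyEx(otsu_mask, cv2.MORPH_CLOSE, kernel, iterations=3)

    # Adipose areas are present in the tissue mask but not in the Otsu mask,
    # so the elementwise product removes them.
    return otsu_mask * tissue_mask

The diff below applies the same ingredients (saturation-channel Otsu, morphological closing, mask multiplication) inside save_landmark_points.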
226 changes: 163 additions & 63 deletions src/assembly_utils/pairwise_alignment_utils.py
@@ -30,6 +30,7 @@ def __init__(self, kwargs):
self.fragment_name = kwargs["fragment_name"]
self.num_fragments = kwargs["n_fragments"]
self.fragment_name_idx = int(self.fragment_name.lstrip("fragment").rstrip(".png")) - 1
self.original_name = self.all_fragment_names[self.fragment_name_idx].split(".")[0]

self.save_dir = kwargs["save_dir"]
self.data_dir = kwargs["data_dir"]
@@ -50,6 +51,8 @@ def __init__(self, kwargs):
for line in data:
self.config_dict[line.split(":")[0]] = line.split(":")[-1].rstrip("\n")

self.location = self.config_dict[self.original_name]

if self.num_fragments == 4:
self.classifier = kwargs["fragment_classifier"]
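A small sketch of how the new location attribute is obtained, assuming a plain-text config in which each line maps an original fragment name to a rough location; the names and values below are made up for illustration and mirror the line.split(":") parsing above.

# Hypothetical contents of the per-case configuration file.
config_text = """tumor_block_A: left
tumor_block_B: right"""

config_dict = {}
for line in config_text.splitlines():
    key, _, value = line.partition(":")
    config_dict[key.strip()] = value.strip()

original_name = "tumor_block_A"        # raw fragment name without extension
location = config_dict[original_name]  # -> "left", later used to pick the stitch edge
print(location)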

@@ -233,14 +236,26 @@ def get_stitch_edges(self):
bbox_center = np.mean(bbox_corners, axis=0)
bbox_corner_dist = bbox_center - bbox_corners

# Expand the bbox in size but with same center point. Using this larger bbox
# for determining stitch edges has shown to be a bit more robust against outliers
# in the contour.
expansion = bbox_center / 2
expand_direction = np.array([list(i < 0) for i in bbox_corner_dist])
bbox_corners_expansion = (expand_direction == True) * expansion + (
expand_direction == False
) * -expansion
# Expand the bbox for a bit more robust point selection on the contour. We expand
# the bbox in only one direction or uniformly, depending on whether we are dealing
# with 2 or 4 fragments. These choices were based on empirical observations.
if hasattr(self, "location"):
if self.location in ["left", "right"]:
bbox_corners_expansion = -np.vstack([bbox_corner_dist[:, 0]*5, np.zeros_like(
bbox_corner_dist[:, 0])]).T

elif self.location in ["top", "bottom"]:
bbox_corners_expansion = -np.vstack([np.zeros_like(bbox_corner_dist[:, 0]),
bbox_corner_dist[:, 1]*5]).T

# Case of 4 fragments expand uniformly.
else:
expansion = bbox_center / 2
expand_direction = np.array([list(i < 0) for i in bbox_corner_dist])
bbox_corners_expansion = (expand_direction == True) * expansion + (
expand_direction == False
) * -expansion

new_bbox_corners = bbox_corners + bbox_corners_expansion
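A standalone numpy sketch of the expansion above with toy coordinates: for a "left"/"right" fragment the box is stretched along x only, while the 4-fragment branch pushes every corner away from the centre. Variable names mirror the diff; the numbers are made up.

import numpy as np

bbox_corners = np.array([[10., 10.], [90., 10.], [90., 60.], [10., 60.]])
bbox_center = bbox_corners.mean(axis=0)
bbox_corner_dist = bbox_center - bbox_corners

# 2-fragment case, location in {"left", "right"}: stretch along x only.
expansion_x_only = -np.vstack(
    [bbox_corner_dist[:, 0] * 5, np.zeros_like(bbox_corner_dist[:, 0])]
).T

# 4-fragment case: push every corner outward, here by half the centre coordinates.
expansion = bbox_center / 2
uniform_expansion = np.where(bbox_corner_dist < 0, expansion, -expansion)

print(bbox_corners + expansion_x_only)   # wider box, unchanged height
print(bbox_corners + uniform_expansion)  # uniformly enlarged box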

### Step 3 ###
@@ -259,23 +274,35 @@ def get_stitch_edges(self):
# Step 4a - scenario with 2 fragments
if self.num_fragments == 2:

# Compute distance from contour corners to surrounding bbox
dist_cnt_corner_to_bbox = np.min(
distance.cdist(self.cnt_corners, new_bbox_corners), axis=1
)
dist_cnt_corner_to_bbox_loop = np.hstack([dist_cnt_corner_to_bbox] * 2)
if hasattr(self, "location"):

# Get the pair of contour corners with largest distance to bounding box, this should
# be the outer point of the contour. Stitch edge is then located opposite.
dist_per_cnt_corner_pair = [
dist_cnt_corner_to_bbox_loop[i] ** 2 + dist_cnt_corner_to_bbox_loop[i + 1] ** 2
for i in range(4)
]
max_dist_corner_idx = np.argmax(dist_per_cnt_corner_pair)
if self.location == "left":
start_corner = self.cnt_corners_loop[ul_cnt_corner_idx + 1]
end_corner = self.cnt_corners_loop[ul_cnt_corner_idx + 2]

elif self.location == "right":
start_corner = self.cnt_corners_loop[ul_cnt_corner_idx + 3]
end_corner = self.cnt_corners_loop[ul_cnt_corner_idx]

else:
# Compute distance from contour corners to surrounding bbox
dist_cnt_corner_to_bbox = np.min(
distance.cdist(self.cnt_corners, new_bbox_corners), axis=1
)
dist_cnt_corner_to_bbox_loop = np.hstack([dist_cnt_corner_to_bbox] * 2)

# Get the pair of contour corners with largest distance to bounding box, this should
# be the outer point of the contour. Stitch edge is then located opposite.
dist_per_cnt_corner_pair = [
dist_cnt_corner_to_bbox_loop[i] ** 2 + dist_cnt_corner_to_bbox_loop[i + 1] ** 2
for i in range(4)
]
max_dist_corner_idx = np.argmax(dist_per_cnt_corner_pair)

# Get location and indices of these contour corners
start_corner = self.cnt_corners_loop[max_dist_corner_idx + 2]
end_corner = self.cnt_corners_loop[max_dist_corner_idx + 3]

# Get location and indices of these contour corners
start_corner = self.cnt_corners_loop[max_dist_corner_idx + 2]
end_corner = self.cnt_corners_loop[max_dist_corner_idx + 3]
start_idx = np.argmax((self.cnt_rdp == start_corner).all(axis=1))
end_idx = np.argmax((self.cnt_rdp == end_corner).all(axis=1)) + 1
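The generic (non-location) branch above picks the stitch edge from distances between contour corners and the enlarged bbox; a toy, self-contained version of that heuristic, with made-up coordinates and logic mirroring the diff:

import numpy as np
from scipy.spatial import distance

cnt_corners = np.array([[2, 2], [10, 3], [9, 9], [1, 8]])              # toy contour corners
cnt_corners_loop = np.vstack([cnt_corners, cnt_corners])               # allows wrap-around indexing
new_bbox_corners = np.array([[-5, -5], [15, -5], [15, 15], [-5, 15]])  # enlarged bbox

# Distance from every contour corner to the nearest enlarged-bbox corner.
dist = np.min(distance.cdist(cnt_corners, new_bbox_corners), axis=1)
dist_loop = np.hstack([dist] * 2)

# The corner pair farthest from the bbox marks the outer side of the fragment;
# the stitch edge is then the opposite pair (offsets +2 and +3 in the loop).
pair_scores = [dist_loop[i] ** 2 + dist_loop[i + 1] ** 2 for i in range(4)]
max_idx = int(np.argmax(pair_scores))
start_corner = cnt_corners_loop[max_idx + 2]
end_corner = cnt_corners_loop[max_idx + 3]
print(start_corner, end_corner)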

@@ -401,10 +428,12 @@ def save_landmark_points(self):
Method to save some automatically detected landmark points. These can be used
later to compute the residual registration mismatch.
Approximate steps:
1. Get lowres version of mask
2. Process it with floodfilling and connected component analysis
3. Compute simplified contour and stitch edges
4.
1. Load images and masks and get lowres version
2. Get Otsu mask and tissue segmentation mask
3. Process tissue segmentation mask
4. Combine masks
5. Compute simplified contour and stitch edges
6. Save landmark points
"""

### STEP 1 ###
@@ -436,55 +465,104 @@ def save_landmark_points(self):
mask2raw_scaling = raw_image_dims[0] / mask_ds_level_dims[0]

# Get downsampled version of image
lowres_mask = self.raw_mask.getUCharPatch(
lowres_image = self.raw_image.getUCharPatch(
startX=0,
startY=0,
width=image_ds_level_dims[0],
height=image_ds_level_dims[1],
level=int(self.landmark_level)
)

### STEP 2 ###
lowres_image_hsv = cv2.cvtColor(lowres_image, cv2.COLOR_RGB2HSV)
lowres_image_hsv = cv2.medianBlur(lowres_image_hsv[:, :, 1], 7)
_, otsu_mask = cv2.threshold(lowres_image_hsv, 0, 255, cv2.THRESH_OTSU +
cv2.THRESH_BINARY)
otsu_mask = (otsu_mask / np.max(otsu_mask)).astype("uint8")

# Postprocess the mask a bit
kernel = cv2.getStructuringElement(shape=cv2.MORPH_ELLIPSE, ksize=(8, 8))
otsu_mask = cv2.morphologyEx(
src=otsu_mask, op=cv2.MORPH_CLOSE, kernel=kernel, iterations=3
)

tissue_mask = self.raw_mask.getUCharPatch(
startX=0,
startY=0,
width=mask_ds_level_dims[0],
height=mask_ds_level_dims[1],
level=int(mask_ds_level)
)

### STEP 2 ###
# Process mask
temp_pad = int(0.05 * lowres_mask.shape[0])
lowres_mask = np.pad(
np.squeeze(lowres_mask),
### STEP 3 ###
# Process tissue mask
temp_pad = int(0.05 * tissue_mask.shape[0])
tissue_mask = np.pad(
np.squeeze(tissue_mask),
[[temp_pad, temp_pad], [temp_pad, temp_pad]],
mode="constant",
constant_values=0,
)

# Get largest component
num_labels, labeled_mask, stats, _ = cv2.connectedComponentsWithStats(
lowres_mask, connectivity=8
tissue_mask, connectivity=8
)
largest_cc_label = np.argmax(stats[1:, -1]) + 1
lowres_mask = ((labeled_mask == largest_cc_label) * 255).astype("uint8")
tissue_mask = ((labeled_mask == largest_cc_label) * 255).astype("uint8")

# Close some small holes
strel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
lowres_mask = cv2.morphologyEx(lowres_mask, cv2.MORPH_CLOSE, strel, iterations=3)
tissue_mask = cv2.morphologyEx(tissue_mask, cv2.MORPH_CLOSE, strel, iterations=3)

# Get floodfilled background
seedpoint = (0, 0)
floodfill_mask = np.zeros(
(lowres_mask.shape[0] + 2, lowres_mask.shape[1] + 2)
(tissue_mask.shape[0] + 2, tissue_mask.shape[1] + 2)
)
floodfill_mask = floodfill_mask.astype("uint8")
_, _, lowres_mask, _ = cv2.floodFill(
lowres_mask,
_, _, tissue_mask, _ = cv2.floodFill(
tissue_mask,
floodfill_mask,
seedpoint,
255
)
lowres_mask = (
1 - lowres_mask[temp_pad + 1: -(temp_pad + 1), temp_pad + 1: -(temp_pad + 1)]
tissue_mask = (
1 - tissue_mask[temp_pad + 1: -(temp_pad + 1), temp_pad + 1: -(temp_pad + 1)]
)

### STEP 3 ###
### STEP 4 ###
# Combine masks
final_mask = otsu_mask * tissue_mask

# Postprocess similar to tissue segmentation mask. Get largest cc and floodfill.
num_labels, labeled_im, stats, _ = cv2.connectedComponentsWithStats(
final_mask, connectivity=8
)
assert num_labels > 1, "mask is empty"
largest_cc_label = np.argmax(stats[1:, -1]) + 1
final_mask = ((labeled_im == largest_cc_label) * 255).astype("uint8")

# Flood fill
offset = 5
final_mask = np.pad(
final_mask,
[[offset, offset], [offset, offset]],
mode="constant",
constant_values=0
)

seedpoint = (0, 0)
floodfill_mask = np.zeros(
(final_mask.shape[0] + 2, final_mask.shape[1] + 2)).astype("uint8")
_, _, final_mask, _ = cv2.floodFill(final_mask, floodfill_mask, seedpoint, 255)
final_mask = final_mask[1 + offset:-1 - offset, 1 + offset:-1 - offset]
final_mask = 1 - final_mask

### STEP 5 ###
# Get contour
cnt, _ = cv2.findContours(
lowres_mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
final_mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
)
cnt = np.squeeze(max(cnt, key=cv2.contourArea))[::-1]
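The mask postprocessing above combines two standard OpenCV steps: keep only the largest connected component, then flood-fill the background from a corner and invert so that interior holes end up filled. A self-contained toy example of that pattern, not the repository's code:

import cv2
import numpy as np

# Toy binary mask: a filled square with a hole punched into it.
mask = np.zeros((60, 60), dtype="uint8")
mask[10:50, 10:50] = 1
mask[25:30, 25:30] = 0  # interior hole

# Keep only the largest connected component (label 0 is the background).
_, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
largest = np.argmax(stats[1:, cv2.CC_STAT_AREA]) + 1
mask = ((labels == largest) * 255).astype("uint8")

# Flood-fill the background from a corner. cv2.floodFill marks every reached
# pixel in the (h+2, w+2) helper mask, so pixels NOT reached are either object
# or enclosed holes; inverting the helper therefore yields a hole-filled mask.
ff_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2), dtype="uint8")
cv2.floodFill(mask, ff_mask, (0, 0), 255)
filled = 1 - ff_mask[1:-1, 1:-1]

print(filled[27, 27])  # 1 -> the interior hole now belongs to the object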

@@ -494,11 +572,22 @@ def save_landmark_points(self):
bbox_center = np.mean(bbox_points, axis=0)
bbox_corner_dist = bbox_center - bbox_points

expansion = bbox_center / 10
expand_direction = np.array([list(i < 0) for i in bbox_corner_dist])
bbox_corners_expansion = (expand_direction == True) * expansion + (
expand_direction == False
) * -expansion
# Again expand bbox
if hasattr(self, "location"):
if self.location in ["left", "right"]:
bbox_corners_expansion = np.vstack([bbox_corner_dist[:, 0]*5, np.zeros_like(
bbox_corner_dist[:, 0])]).T

elif self.location in ["top", "bottom"]:
bbox_corners_expansion = np.vstack([np.zeros_like(bbox_corner_dist[:, 0]), bbox_corner_dist[:, 1]*5]).T

# Case of 4 fragments expand uniformly.
else:
expansion = bbox_center / 2
expand_direction = np.array([list(i < 0) for i in bbox_corner_dist])
bbox_corners_expansion = (expand_direction == True) * expansion + (
expand_direction == False
) * -expansion
new_bbox_points = bbox_points + bbox_corners_expansion

# Get closest point on contour as seen from enlarged bbox
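A comment like the one above usually corresponds to a nearest-neighbour lookup from the expanded bbox corners onto the contour; a generic scipy sketch with made-up coordinates, not code from this commit:

import numpy as np
from scipy.spatial import distance

cnt = np.array([[0, 0], [10, 0], [10, 10], [0, 10], [0, 5]])          # toy contour points
new_bbox_points = np.array([[-2, -2], [12, -2], [12, 12], [-2, 12]])  # expanded bbox corners

# For every bbox corner, the index of the closest contour point.
closest_idx = np.argmin(distance.cdist(new_bbox_points, cnt), axis=1)
cnt_corners = cnt[closest_idx]
print(cnt_corners)  # the contour point nearest to each bbox corner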
@@ -513,23 +602,34 @@ def save_landmark_points(self):
cnt_fragments = []
if self.num_fragments == 2:

# Compute distance from contour corners to surrounding bbox
dist_cnt_corner_to_bbox = np.min(
distance.cdist(cnt_corners, new_bbox_points), axis=1
)
dist_cnt_corner_to_bbox_loop = np.hstack([dist_cnt_corner_to_bbox] * 2)
if hasattr(self, "location"):
if self.location == "left":
start_corner = cnt_corners_loop[ul_cnt_corner_idx + 1]
end_corner = cnt_corners_loop[ul_cnt_corner_idx + 2]

# Get the pair of contour corners with largest distance to bounding box, this should
# be the outer point of the contour. Stitch edge is then located opposite.
dist_per_cnt_corner_pair = [
dist_cnt_corner_to_bbox_loop[i] ** 2 + dist_cnt_corner_to_bbox_loop[i + 1] ** 2
for i in range(4)
]
max_dist_corner_idx = np.argmax(dist_per_cnt_corner_pair)
elif self.location == "right":
start_corner = cnt_corners_loop[ul_cnt_corner_idx + 3]
end_corner = cnt_corners_loop[ul_cnt_corner_idx]

else:
# Compute distance from contour corners to surrounding bbox
dist_cnt_corner_to_bbox = np.min(
distance.cdist(cnt_corners, new_bbox_points), axis=1
)
dist_cnt_corner_to_bbox_loop = np.hstack([dist_cnt_corner_to_bbox] * 2)

# Get the pair of contour corners with largest distance to bounding box, this should
# be the outer point of the contour. Stitch edge is then located opposite.
dist_per_cnt_corner_pair = [
dist_cnt_corner_to_bbox_loop[i] ** 2 + dist_cnt_corner_to_bbox_loop[i + 1] ** 2
for i in range(4)
]
max_dist_corner_idx = np.argmax(dist_per_cnt_corner_pair)

# Get location and indices of these contour corners
start_corner = cnt_corners_loop[max_dist_corner_idx + 2]
end_corner = cnt_corners_loop[max_dist_corner_idx + 3]

# Get location and indices of these contour corners
start_corner = cnt_corners_loop[max_dist_corner_idx + 2]
end_corner = cnt_corners_loop[max_dist_corner_idx + 3]
start_idx = np.argmax((cnt == start_corner).all(axis=1))
end_idx = np.argmax((cnt == end_corner).all(axis=1)) + 1

@@ -541,7 +641,7 @@ def save_landmark_points(self):

# Sanity check
plt.figure()
plt.imshow(lowres_mask)
plt.imshow(final_mask)
plt.plot(cnt_fragments[0][:, 0], cnt_fragments[0][:, 1], c="r")
plt.show()
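Once the start and end corners are known, the stitch edge is the contour segment between their row indices; a minimal sketch of that lookup with toy data, not the repository's code:

import numpy as np

cnt = np.array([[0, 0], [5, 0], [10, 0], [10, 5], [10, 10], [0, 10]])  # toy contour
start_corner = np.array([10, 0])
end_corner = np.array([10, 10])

# Row index of the first contour point equal to each corner.
start_idx = np.argmax((cnt == start_corner).all(axis=1))
end_idx = np.argmax((cnt == end_corner).all(axis=1)) + 1

cnt_fragment = cnt[start_idx:end_idx]
print(cnt_fragment)  # the points from [10, 0] up to and including [10, 10]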

6 changes: 3 additions & 3 deletions src/main.py
@@ -42,9 +42,9 @@ def load_parameter_configuration(data_dir, save_dir, output_res):
parameters["save_dir"] = save_dir
parameters["patient_idx"] = data_dir.name
parameters["output_res"] = output_res
parameters["fragment_names"] = [
parameters["fragment_names"] = sorted([
i.name for i in data_dir.joinpath("raw_images").iterdir() if not i.is_dir()
]
])
parameters["n_fragments"] = len(parameters["fragment_names"])
parameters["resolution_scaling"] = [
i / parameters["resolutions"][0] for i in parameters["resolutions"]
@@ -85,7 +85,7 @@ def collect_arguments():

# Parse arguments
parser = argparse.ArgumentParser(
description="Stitch prostate histopathology images into a pseudo whole-mount image"
description="Stitch histopathology images into a pseudo whole-mount image"
)
parser.add_argument(
"--datadir", required=True, type=pathlib.Path, help="Path to the case to stitch"
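The sorted() added in load_parameter_configuration makes the fragment list independent of filesystem iteration order, which matters because fragment_name_idx (derived in pairwise_alignment_utils.py above) indexes into this list to recover the original name and its location. A small illustration with hypothetical file names:

# Filesystem iteration order is not guaranteed, so sort for a stable mapping
# between the "fragmentN.png" working names and the original raw image names.
all_fragment_names = sorted(["tumor_block_B.tiff", "tumor_block_A.tiff"])  # hypothetical

fragment_name = "fragment2.png"
fragment_name_idx = int(fragment_name.lstrip("fragment").rstrip(".png")) - 1
original_name = all_fragment_names[fragment_name_idx].split(".")[0]
print(original_name)  # -> "tumor_block_B"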