From 4274a01621f649fb73ac4763fa61521ee04d9d39 Mon Sep 17 00:00:00 2001
From: dnschouten
Date: Fri, 26 May 2023 13:40:26 +0200
Subject: [PATCH] optimized tissue segmentation

Implemented a more robust tissue segmentation in which adipose tissue is
largely excluded from the stitch. Adipose tissue is highly deformable and
should not drive the stitch result, so this commit filters it out of the
masks.
---
 .../pairwise_alignment_utils.py              | 226 +++++++++++++-----
 src/main.py                                  |   6 +-
 src/preprocessing_utils/prepare_data.py      |  87 +++++--
 src/pythostitcher_utils/fragment_class.py    |  45 +++-
 src/pythostitcher_utils/full_resolution.py   | 179 +++++++++++---
 src/pythostitcher_utils/gradient_blending.py |   7 +-
 6 files changed, 421 insertions(+), 129 deletions(-)

diff --git a/src/assembly_utils/pairwise_alignment_utils.py b/src/assembly_utils/pairwise_alignment_utils.py
index ecedadf..f371bfb 100644
--- a/src/assembly_utils/pairwise_alignment_utils.py
+++ b/src/assembly_utils/pairwise_alignment_utils.py
@@ -30,6 +30,7 @@ def __init__(self, kwargs):
         self.fragment_name = kwargs["fragment_name"]
         self.num_fragments = kwargs["n_fragments"]
         self.fragment_name_idx = int(self.fragment_name.lstrip("fragment").rstrip(".png")) - 1
+        self.original_name = self.all_fragment_names[self.fragment_name_idx].split(".")[0]

         self.save_dir = kwargs["save_dir"]
         self.data_dir = kwargs["data_dir"]
@@ -50,6 +51,8 @@ def __init__(self, kwargs):
         for line in data:
             self.config_dict[line.split(":")[0]] = line.split(":")[-1].rstrip("\n")

+        self.location = self.config_dict[self.original_name]
+
         if self.num_fragments == 4:
             self.classifier = kwargs["fragment_classifier"]

@@ -233,14 +236,26 @@ def get_stitch_edges(self):
         bbox_center = np.mean(bbox_corners, axis=0)
         bbox_corner_dist = bbox_center - bbox_corners

-        # Expand the bbox in size but with same center point. Using this larger bbox
-        # for determining stitch edges has shown to be a bit more robust against outliers
-        # in the contour.
-        expansion = bbox_center / 2
-        expand_direction = np.array([list(i < 0) for i in bbox_corner_dist])
-        bbox_corners_expansion = (expand_direction == True) * expansion + (
-            expand_direction == False
-        ) * -expansion
+        # Expand the bbox for a bit more robust point selection on the contour. We expand
+        # the bbox in only one direction or uniformly, depending on whether we are dealing
+        # with 2 or 4 fragments. These choices were based on empirical observations.
+        if hasattr(self, "location"):
+            if self.location in ["left", "right"]:
+                bbox_corners_expansion = -np.vstack([bbox_corner_dist[:, 0] * 5, np.zeros_like(
+                    bbox_corner_dist[:, 0])]).T
+
+            elif self.location in ["top", "bottom"]:
+                bbox_corners_expansion = -np.vstack([np.zeros_like(bbox_corner_dist[:, 0]),
+                    bbox_corner_dist[:, 1] * 5]).T
+
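A minimal standalone sketch of the directional bbox expansion used in the two branches above, with hypothetical corner values (in the patch the corners come from cv2.boxPoints). A corner left of the center has a positive center-minus-corner x-distance, so negating the stacked distances pushes every corner outward along the chosen axis only:

    import numpy as np

    bbox_corners = np.array([[10., 10.], [110., 10.], [110., 60.], [10., 60.]])
    bbox_center = np.mean(bbox_corners, axis=0)
    bbox_corner_dist = bbox_center - bbox_corners

    # "left"/"right" case: stretch along x only, keep y fixed
    expansion = -np.vstack([bbox_corner_dist[:, 0] * 5,
                            np.zeros_like(bbox_corner_dist[:, 0])]).T

    # "top"/"bottom" case: stretch along y only, keep x fixed
    # expansion = -np.vstack([np.zeros_like(bbox_corner_dist[:, 1]),
    #                         bbox_corner_dist[:, 1] * 5]).T

    new_bbox_corners = bbox_corners + expansion  # (4, 2), stretched horizontally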
+        # In the case of 4 fragments, expand uniformly.
+        else:
+            expansion = bbox_center / 2
+            expand_direction = np.array([list(i < 0) for i in bbox_corner_dist])
+            bbox_corners_expansion = (expand_direction == True) * expansion + (
+                expand_direction == False
+            ) * -expansion
+
         new_bbox_corners = bbox_corners + bbox_corners_expansion

         ### Step 3 ###
@@ -259,23 +274,35 @@ def get_stitch_edges(self):

         # Step 4a - scenario with 2 fragments
         if self.num_fragments == 2:

-            # Compute distance from contour corners to surrounding bbox
-            dist_cnt_corner_to_bbox = np.min(
-                distance.cdist(self.cnt_corners, new_bbox_corners), axis=1
-            )
-            dist_cnt_corner_to_bbox_loop = np.hstack([dist_cnt_corner_to_bbox] * 2)
+            if hasattr(self, "location"):

-            # Get the pair of contour corners with largest distance to bounding box, this should
-            # be the outer point of the contour. Stitch edge is then located opposite.
-            dist_per_cnt_corner_pair = [
-                dist_cnt_corner_to_bbox_loop[i] ** 2 + dist_cnt_corner_to_bbox_loop[i + 1] ** 2
-                for i in range(4)
-            ]
-            max_dist_corner_idx = np.argmax(dist_per_cnt_corner_pair)
+                if self.location == "left":
+                    start_corner = self.cnt_corners_loop[ul_cnt_corner_idx + 1]
+                    end_corner = self.cnt_corners_loop[ul_cnt_corner_idx + 2]
+
+                elif self.location == "right":
+                    start_corner = self.cnt_corners_loop[ul_cnt_corner_idx + 3]
+                    end_corner = self.cnt_corners_loop[ul_cnt_corner_idx]
+
+            else:
+                # Compute distance from contour corners to surrounding bbox
+                dist_cnt_corner_to_bbox = np.min(
+                    distance.cdist(self.cnt_corners, new_bbox_corners), axis=1
+                )
+                dist_cnt_corner_to_bbox_loop = np.hstack([dist_cnt_corner_to_bbox] * 2)
+
+                # Get the pair of contour corners with largest distance to bounding box; this
+                # should be the outer edge of the contour. Stitch edge is then located opposite.
+                dist_per_cnt_corner_pair = [
+                    dist_cnt_corner_to_bbox_loop[i] ** 2 + dist_cnt_corner_to_bbox_loop[i + 1] ** 2
+                    for i in range(4)
+                ]
+                max_dist_corner_idx = np.argmax(dist_per_cnt_corner_pair)
+
+                # Get location and indices of these contour corners
+                start_corner = self.cnt_corners_loop[max_dist_corner_idx + 2]
+                end_corner = self.cnt_corners_loop[max_dist_corner_idx + 3]

-            # Get location and indices of these contour corners
-            start_corner = self.cnt_corners_loop[max_dist_corner_idx + 2]
-            end_corner = self.cnt_corners_loop[max_dist_corner_idx + 3]
             start_idx = np.argmax((self.cnt_rdp == start_corner).all(axis=1))
             end_idx = np.argmax((self.cnt_rdp == end_corner).all(axis=1)) + 1
@@ -401,10 +428,12 @@ def save_landmark_points(self):
         Method to save some automatically detected landmark points. These can be
         used later to compute the residual registration mismatch. Approximate steps:
-        1. Get lowres version of mask
-        2. Process it with floodfilling and connected component analysis
-        3. Compute simplified contour and stitch edges
-        4.
+        1. Load images and masks and get lowres version
+        2. Get Otsu mask and tissue segmentation mask
+        3. Process tissue segmentation mask
+        4. Combine masks
+        5. Compute simplified contour and stitch edges
+        6. Compute and save landmark points on the stitch edges
""" ### STEP 1 ### @@ -436,7 +465,28 @@ def save_landmark_points(self): mask2raw_scaling = raw_image_dims[0] / mask_ds_level_dims[0] # Get downsampled version of mask - lowres_mask = self.raw_mask.getUCharPatch( + lowres_image = self.raw_image.getUCharPatch( + startX=0, + startY=0, + width=image_ds_level_dims[0], + height=image_ds_level_dims[1], + level=int(self.landmark_level) + ) + + ### STEP 2 ### + lowres_image_hsv = cv2.cvtColor(lowres_image, cv2.COLOR_RGB2HSV) + lowres_image_hsv = cv2.medianBlur(lowres_image_hsv[:, :, 1], 7) + _, otsu_mask = cv2.threshold(lowres_image_hsv, 0, 255, cv2.THRESH_OTSU + + cv2.THRESH_BINARY) + otsu_mask = (otsu_mask / np.max(otsu_mask)).astype("uint8") + + # Postprocess the mask a bit + kernel = cv2.getStructuringElement(shape=cv2.MORPH_ELLIPSE, ksize=(8, 8)) + otsu_mask = cv2.morphologyEx( + src=otsu_mask, op=cv2.MORPH_CLOSE, kernel=kernel, iterations=3 + ) + + tissue_mask = self.raw_mask.getUCharPatch( startX=0, startY=0, width=mask_ds_level_dims[0], @@ -444,11 +494,11 @@ def save_landmark_points(self): level=int(mask_ds_level) ) - ### STEP 2 ### - # Process mask - temp_pad = int(0.05 * lowres_mask.shape[0]) - lowres_mask = np.pad( - np.squeeze(lowres_mask), + ### STEP 3 ### + # Process tissue mask + temp_pad = int(0.05 * tissue_mask.shape[0]) + tissue_mask = np.pad( + np.squeeze(tissue_mask), [[temp_pad, temp_pad], [temp_pad, temp_pad]], mode="constant", constant_values=0, @@ -456,35 +506,63 @@ def save_landmark_points(self): # Get largest component num_labels, labeled_mask, stats, _ = cv2.connectedComponentsWithStats( - lowres_mask, connectivity=8 + tissue_mask, connectivity=8 ) largest_cc_label = np.argmax(stats[1:, -1]) + 1 - lowres_mask = ((labeled_mask == largest_cc_label) * 255).astype("uint8") + tissue_mask = ((labeled_mask == largest_cc_label) * 255).astype("uint8") # Close some small holes strel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10)) - lowres_mask = cv2.morphologyEx(lowres_mask, cv2.MORPH_CLOSE, strel, iterations=3) + tissue_mask = cv2.morphologyEx(tissue_mask, cv2.MORPH_CLOSE, strel, iterations=3) # Get floodfilled background seedpoint = (0, 0) floodfill_mask = np.zeros( - (lowres_mask.shape[0] + 2, lowres_mask.shape[1] + 2) + (tissue_mask.shape[0] + 2, tissue_mask.shape[1] + 2) ) floodfill_mask = floodfill_mask.astype("uint8") - _, _, lowres_mask, _ = cv2.floodFill( - lowres_mask, + _, _, tissue_mask, _ = cv2.floodFill( + tissue_mask, floodfill_mask, seedpoint, 255 ) - lowres_mask = ( - 1 - lowres_mask[temp_pad + 1: -(temp_pad + 1), temp_pad + 1: -(temp_pad + 1)] + tissue_mask = ( + 1 - tissue_mask[temp_pad + 1: -(temp_pad + 1), temp_pad + 1: -(temp_pad + 1)] ) - ### STEP 3 ### + ### STEP 4 ### + # Combine masks + final_mask = otsu_mask * tissue_mask + + # Postprocess similar to tissue segmentation mask. Get largest cc and floodfill. 
+        num_labels, labeled_im, stats, _ = cv2.connectedComponentsWithStats(
+            final_mask, connectivity=8
+        )
+        assert num_labels > 1, "mask is empty"
+        largest_cc_label = np.argmax(stats[1:, -1]) + 1
+        final_mask = ((labeled_im == largest_cc_label) * 255).astype("uint8")
+
+        # Flood fill
+        offset = 5
+        final_mask = np.pad(
+            final_mask,
+            [[offset, offset], [offset, offset]],
+            mode="constant",
+            constant_values=0
+        )
+
+        seedpoint = (0, 0)
+        floodfill_mask = np.zeros(
+            (final_mask.shape[0] + 2, final_mask.shape[1] + 2)).astype("uint8")
+        _, _, final_mask, _ = cv2.floodFill(final_mask, floodfill_mask, seedpoint, 255)
+        final_mask = final_mask[1 + offset:-1 - offset, 1 + offset:-1 - offset]
+        final_mask = 1 - final_mask
+
+        ### STEP 5 ###
         # Get contour
         cnt, _ = cv2.findContours(
-            lowres_mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
+            final_mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
         )
         cnt = np.squeeze(max(cnt, key=cv2.contourArea))[::-1]
@@ -494,11 +572,22 @@ def save_landmark_points(self):
         bbox_center = np.mean(bbox_points, axis=0)
         bbox_corner_dist = bbox_center - bbox_points

-        expansion = bbox_center / 10
-        expand_direction = np.array([list(i < 0) for i in bbox_corner_dist])
-        bbox_corners_expansion = (expand_direction == True) * expansion + (
+        # Again expand bbox
+        if hasattr(self, "location"):
+            if self.location in ["left", "right"]:
+                bbox_corners_expansion = np.vstack([bbox_corner_dist[:, 0] * 5, np.zeros_like(
+                    bbox_corner_dist[:, 0])]).T
+
+            elif self.location in ["top", "bottom"]:
+                bbox_corners_expansion = np.vstack([np.zeros_like(bbox_corner_dist[:, 0]), bbox_corner_dist[:, 1] * 5]).T
+
+        # In the case of 4 fragments, expand uniformly.
+        else:
+            expansion = bbox_center / 2
+            expand_direction = np.array([list(i < 0) for i in bbox_corner_dist])
+            bbox_corners_expansion = (expand_direction == True) * expansion + (
             expand_direction == False
-        ) * -expansion
+            ) * -expansion
         new_bbox_points = bbox_points + bbox_corners_expansion

         # Get closest point on contour as seen from enlarged bbox
@@ -513,23 +602,34 @@ def save_landmark_points(self):
         cnt_fragments = []
         if self.num_fragments == 2:

-            # Compute distance from contour corners to surrounding bbox
-            dist_cnt_corner_to_bbox = np.min(
-                distance.cdist(cnt_corners, new_bbox_points), axis=1
-            )
-            dist_cnt_corner_to_bbox_loop = np.hstack([dist_cnt_corner_to_bbox] * 2)
+            if hasattr(self, "location"):
+                if self.location == "left":
+                    start_corner = cnt_corners_loop[ul_cnt_corner_idx + 1]
+                    end_corner = cnt_corners_loop[ul_cnt_corner_idx + 2]

-            # Get the pair of contour corners with largest distance to bounding box, this should
-            # be the outer point of the contour. Stitch edge is then located opposite.
-            dist_per_cnt_corner_pair = [
-                dist_cnt_corner_to_bbox_loop[i] ** 2 + dist_cnt_corner_to_bbox_loop[i + 1] ** 2
-                for i in range(4)
-            ]
-            max_dist_corner_idx = np.argmax(dist_per_cnt_corner_pair)
+                elif self.location == "right":
+                    start_corner = cnt_corners_loop[ul_cnt_corner_idx + 3]
+                    end_corner = cnt_corners_loop[ul_cnt_corner_idx]
+
+            else:
+                # Compute distance from contour corners to surrounding bbox
+                dist_cnt_corner_to_bbox = np.min(
+                    distance.cdist(cnt_corners, new_bbox_points), axis=1
+                )
+                dist_cnt_corner_to_bbox_loop = np.hstack([dist_cnt_corner_to_bbox] * 2)
+
+                # Get the pair of contour corners with largest distance to bounding box; this
+                # should be the outer edge of the contour. Stitch edge is then located opposite.
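A toy illustration of this fallback branch with hypothetical coordinates: the adjacent corner pair with the largest summed squared distance to the expanded bbox marks the outer edge, and the stitch edge is taken two corners further along the looped corner list:

    import numpy as np
    from scipy.spatial import distance

    cnt_corners = np.array([[0., 0.], [100., 5.], [95., 50.], [5., 45.]])
    cnt_corners_loop = np.vstack([cnt_corners] * 2)
    new_bbox_points = np.array([[-20., -20.], [120., -20.], [120., 70.], [-20., 70.]])

    d = np.min(distance.cdist(cnt_corners, new_bbox_points), axis=1)
    d_loop = np.hstack([d] * 2)

    # Adjacent pair with the largest summed squared distance = outer edge
    pair_dists = [d_loop[i] ** 2 + d_loop[i + 1] ** 2 for i in range(4)]
    max_dist_corner_idx = int(np.argmax(pair_dists))

    # The stitch edge lies opposite: shift two corners along the loop
    start_corner = cnt_corners_loop[max_dist_corner_idx + 2]
    end_corner = cnt_corners_loop[max_dist_corner_idx + 3]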
+ dist_per_cnt_corner_pair = [ + dist_cnt_corner_to_bbox_loop[i] ** 2 + dist_cnt_corner_to_bbox_loop[i + 1] ** 2 + for i in range(4) + ] + max_dist_corner_idx = np.argmax(dist_per_cnt_corner_pair) + + # Get location and indices of these contour corners + start_corner = cnt_corners_loop[max_dist_corner_idx + 2] + end_corner = cnt_corners_loop[max_dist_corner_idx + 3] - # Get location and indices of these contour corners - start_corner = cnt_corners_loop[max_dist_corner_idx + 2] - end_corner = cnt_corners_loop[max_dist_corner_idx + 3] start_idx = np.argmax((cnt == start_corner).all(axis=1)) end_idx = np.argmax((cnt == end_corner).all(axis=1)) + 1 @@ -541,7 +641,7 @@ def save_landmark_points(self): # Sanity check plt.figure() - plt.imshow(lowres_mask) + plt.imshow(final_mask) plt.plot(cnt_fragments[0][:, 0], cnt_fragments[0][:, 1], c="r") plt.show() diff --git a/src/main.py b/src/main.py index 0d5d4a6..b6e33e9 100644 --- a/src/main.py +++ b/src/main.py @@ -42,9 +42,9 @@ def load_parameter_configuration(data_dir, save_dir, output_res): parameters["save_dir"] = save_dir parameters["patient_idx"] = data_dir.name parameters["output_res"] = output_res - parameters["fragment_names"] = [ + parameters["fragment_names"] = sorted([ i.name for i in data_dir.joinpath("raw_images").iterdir() if not i.is_dir() - ] + ]) parameters["n_fragments"] = len(parameters["fragment_names"]) parameters["resolution_scaling"] = [ i / parameters["resolutions"][0] for i in parameters["resolutions"] @@ -85,7 +85,7 @@ def collect_arguments(): # Parse arguments parser = argparse.ArgumentParser( - description="Stitch prostate histopathology images into a pseudo whole-mount image" + description="Stitch histopathology images into a pseudo whole-mount image" ) parser.add_argument( "--datadir", required=True, type=pathlib.Path, help="Path to the case to stitch" diff --git a/src/preprocessing_utils/prepare_data.py b/src/preprocessing_utils/prepare_data.py index fdfa6eb..494cae4 100644 --- a/src/preprocessing_utils/prepare_data.py +++ b/src/preprocessing_utils/prepare_data.py @@ -46,6 +46,10 @@ def load(self): self.new_dims = self.raw_image.getLevelDimensions(self.new_level) self.image = self.raw_image.getUCharPatch(0, 0, *self.new_dims, self.new_level) + ### EXPERIMENTAL - shift black background to white for correct otsu thresholding### + self.image[np.all(self.image<10, axis=2)] = 255 + ### \\\ EXPERIMENTAL ### + # Get downsampled mask with same dimensions as downsampled image if self.mask_provided: mask_dims = [ @@ -57,24 +61,35 @@ def load(self): self.mask = np.squeeze(self.mask) self.mask = ((self.mask / np.max(self.mask)) * 255).astype("uint8") - # Generic mask generation if mask is not provided. else: - # Process image and threshold for initial mask creation - img_hsv = cv2.cvtColor(self.image, cv2.COLOR_RGB2HSV) - img_med = cv2.medianBlur(img_hsv[:, :, 1], 7) - _, self.mask = cv2.threshold(img_med, 0, 255, cv2.THRESH_OTSU + cv2.THRESH_BINARY) + raise ValueError("PythoStitcher requires a tissue mask for stitching") + + return - # Close some holes - kernel = cv2.getStructuringElement(shape=cv2.MORPH_ELLIPSE, ksize=(10, 10)) - self.mask = cv2.morphologyEx( - src=self.mask, op=cv2.MORPH_CLOSE, kernel=kernel, iterations=2 - ) + def get_otsu_mask(self): + """ + Method to get the mask using Otsu thresholding. This mask will be combined with + the tissue segmentation mask in order to filter out the fatty tissue. 
+ """ + + # Convert to HSV space and perform Otsu thresholding + self.image_hsv = cv2.cvtColor(self.image, cv2.COLOR_RGB2HSV) + self.image_hsv = cv2.medianBlur(self.image_hsv[:, :, 1], 7) + _, self.otsu_mask = cv2.threshold(self.image_hsv, 0, 255, cv2.THRESH_OTSU + + cv2.THRESH_BINARY) + self.otsu_mask = (self.otsu_mask / np.max(self.otsu_mask)).astype("uint8") + + # Postprocess the mask a bit + kernel = cv2.getStructuringElement(shape=cv2.MORPH_ELLIPSE, ksize=(8, 8)) + self.otsu_mask = cv2.morphologyEx( + src=self.otsu_mask, op=cv2.MORPH_CLOSE, kernel=kernel, iterations=3 + ) return - def postprocess(self): + def get_tissueseg_mask(self): """ - Function to postprocess the mask. This mainly consists + Function to postprocess both the regular and the mask. This mainly consists of getting the largest component and then cleaning up this mask. """ @@ -92,7 +107,7 @@ def postprocess(self): # Closing operation to close some holes on the mask border kernel = cv2.getStructuringElement(shape=cv2.MORPH_ELLIPSE, ksize=(10, 10)) - self.mask = cv2.morphologyEx(src=self.mask, op=cv2.MORPH_CLOSE, kernel=kernel, iterations=3) + self.mask = cv2.morphologyEx(src=self.mask, op=cv2.MORPH_CLOSE, kernel=kernel, iterations=2) # Temporarily enlarge mask for succesful floodfill later offset = 5 @@ -112,15 +127,45 @@ def postprocess(self): assert np.sum(self.mask) > 0, "floodfilled mask is empty" - # Perhaps 1 iter of erosion - # self.mask = cv2.morphologyEx(src=self.mask, op=cv2.MORPH_CLOSE, kernel=kernel, iterations=3) + return + + def combine_masks(self): + """ + Method to combine the tissue mask and the Otsu mask for fat filtering. + """ + + # Combine + self.final_mask = self.otsu_mask * self.mask + + # Postprocess similar to tissue segmentation mask. Get largest cc and floodfill. 
+ num_labels, labeled_im, stats, _ = cv2.connectedComponentsWithStats( + self.final_mask, connectivity=8 + ) + assert num_labels > 1, "mask is empty" + largest_cc_label = np.argmax(stats[1:, -1]) + 1 + self.final_mask = ((labeled_im == largest_cc_label) * 255).astype("uint8") + + # Flood fill + offset = 5 + self.final_mask = np.pad( + self.final_mask, + [[offset, offset], [offset, offset]], + mode="constant", + constant_values=0 + ) + + seedpoint = (0, 0) + floodfill_mask = np.zeros((self.final_mask.shape[0] + 2, self.final_mask.shape[1] + 2)).astype("uint8") + _, _, self.final_mask, _ = cv2.floodFill(self.final_mask, floodfill_mask, seedpoint, 255) + self.final_mask = self.final_mask[1 + offset:-1 - offset, 1 + offset:-1 - offset] + self.final_mask = 1 - self.final_mask # Crop to nonzero pixels for efficient saving - r, c = np.nonzero(self.mask) - self.mask = self.mask[np.min(r) : np.max(r), np.min(c) : np.max(c)] + r, c = np.nonzero(self.final_mask) + self.final_mask = self.final_mask[np.min(r) : np.max(r), np.min(c) : np.max(c)] self.image = self.image[np.min(r) : np.max(r), np.min(c) : np.max(c)] self.image = self.image.astype("uint8") - self.mask = (self.mask * 255).astype("uint8") + self.final_mask = (self.final_mask * 255).astype("uint8") return @@ -143,7 +188,7 @@ def save(self): mask_savedir.mkdir() mask_savefile = mask_savedir.joinpath(f"fragment{self.count}.png") - cv2.imwrite(str(mask_savefile), self.mask) + cv2.imwrite(str(mask_savefile), self.final_mask) return @@ -180,7 +225,9 @@ def prepare_data(parameters): count=c, ) data_processor.load() - data_processor.postprocess() + data_processor.get_otsu_mask() + data_processor.get_tissueseg_mask() + data_processor.combine_masks() data_processor.save() parameters["log"].log(parameters["my_level"], " > finished!\n") diff --git a/src/pythostitcher_utils/fragment_class.py b/src/pythostitcher_utils/fragment_class.py index c5b9ab0..89078c4 100644 --- a/src/pythostitcher_utils/fragment_class.py +++ b/src/pythostitcher_utils/fragment_class.py @@ -307,6 +307,32 @@ def get_bbox_corners(self, image): self.bbox = cv2.minAreaRect(self.cnt) self.bbox_corners = cv2.boxPoints(self.bbox) + bbox_center = np.mean(self.bbox_corners, axis=0) + bbox_corner_dist = bbox_center - self.bbox_corners + # Again expand bbox + if self.final_orientation in ["left", "right"]: + # bbox_corners_expansion = -np.vstack([bbox_corner_dist[:, 0] * 2, np.zeros_like( + # bbox_corner_dist[:, 0])]).T + bbox_corners_expansion = -np.vstack( + [bbox_corner_dist[:, 0] * 2, bbox_corner_dist[:, 1] * 0.5] + ) + bbox_corners_expansion = bbox_corners_expansion.astype("int").T + + elif self.final_orientation in ["top", "bottom"]: + bbox_corners_expansion = -np.vstack( + [bbox_corner_dist[:, 0] * 0.5, [bbox_corner_dist[:, 1] * 2]] + ) + bbox_corners_expansion = bbox_corners_expansion.astype("int").T + + # Case of 4 fragments expand uniformly. + else: + expansion = bbox_center / 2 + expand_direction = np.array([list(i < 0) for i in bbox_corner_dist]) + bbox_corners_expansion = (expand_direction == True) * expansion + ( + expand_direction == False + ) * -expansion + self.new_bbox_corners = self.bbox_corners + bbox_corners_expansion + # Get list of x-y values of contour x_points = [i[0] for i in self.cnt] y_points = [i[1] for i in self.cnt] @@ -322,7 +348,7 @@ def get_bbox_corners(self, image): ### other corners are named in clockwise direction. 
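The per-corner loop that follows finds, for each (expanded) bbox corner, the nearest point on the mask contour via explicit x/y distance lists; a vectorized equivalent for a single corner, with hypothetical data, would be:

    import numpy as np

    corner = np.array([120., -15.])
    cnt = np.array([[10., 10.], [60., 8.], [110., 12.], [115., 40.]])

    dist = np.linalg.norm(cnt - corner, axis=1)  # distance to every contour point
    mask_corner_idx = int(np.argmin(dist))       # index of the closest point
    mask_corner = cnt[mask_corner_idx]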
# Get distance from each corner to the mask - for corner in self.bbox_corners: + for corner in self.new_bbox_corners: dist_x = [np.abs(corner[0] - x_point) for x_point in x_points] dist_y = [np.abs(corner[1] - y_point) for y_point in y_points] dist = [np.sqrt(x ** 2 + y ** 2) for x, y in zip(dist_x, dist_y)] @@ -339,18 +365,18 @@ def get_bbox_corners(self, image): # Define stitching edge based on two points with largest/smallest x/y # coordinates. if self.final_orientation == "top": - corner_idxs = np.argsort(self.bbox_corners[:, 1])[-2:] + corner_idxs = np.argsort(self.new_bbox_corners[:, 1])[-2:] elif self.final_orientation == "bottom": - corner_idxs = np.argsort(self.bbox_corners[:, 1])[:2] + corner_idxs = np.argsort(self.new_bbox_corners[:, 1])[:2] elif self.final_orientation == "left": - corner_idxs = np.argsort(self.bbox_corners[:, 0])[-2:] + corner_idxs = np.argsort(self.new_bbox_corners[:, 0])[-2:] elif self.final_orientation == "right": - corner_idxs = np.argsort(self.bbox_corners[:, 0])[:2] + corner_idxs = np.argsort(self.new_bbox_corners[:, 0])[:2] - self.bbox_corner_a = self.bbox_corners[corner_idxs[0]] + self.bbox_corner_a = self.new_bbox_corners[corner_idxs[0]] self.mask_corner_a = mask_corners[corner_idxs[0]] self.mask_corner_a_idx = mask_corners_idxs[corner_idxs[0]] - self.bbox_corner_b = self.bbox_corners[corner_idxs[1]] + self.bbox_corner_b = self.new_bbox_corners[corner_idxs[1]] self.mask_corner_b = mask_corners[corner_idxs[1]] self.mask_corner_b_idx = mask_corners_idxs[corner_idxs[1]] @@ -750,7 +776,9 @@ def get_tformed_images(self, tform): self.tform_image = cv2.warpAffine( src=self.gray_image_original, M=rot_mat, dsize=output_shape[::-1] ) - self.mask = cv2.warpAffine(src=self.mask_original, M=rot_mat, dsize=output_shape[::-1]) + self.mask = cv2.warpAffine( + src=self.mask_original, M=rot_mat, dsize=output_shape[::-1] + ) # Save image center after transformation. This will be needed for the cost # function later on. @@ -800,6 +828,7 @@ def compute_edges(self): ) edge_ab = option_a if len(option_a) < len(option_b) else option_b + edge_ab = np.array(edge_ab) edge_ad = None # With 4 fragments we have to take the other cornerpoints into account. diff --git a/src/pythostitcher_utils/full_resolution.py b/src/pythostitcher_utils/full_resolution.py index 6593a37..5392298 100644 --- a/src/pythostitcher_utils/full_resolution.py +++ b/src/pythostitcher_utils/full_resolution.py @@ -7,6 +7,7 @@ import copy import matplotlib.pyplot as plt import logging +import math from .get_resname import get_resname from .fuse_images_highres import fuse_images_highres @@ -145,55 +146,51 @@ def get_scaling(self): return - def process_mask(self): + def get_tissue_seg_mask(self): """ - Process the mask to a full resolution version + Get the mask from the tissue segmentation algorithm and postprocess. """ - # First process the coordinates - self.line_a = (self.line_a / self.scaling_coords2outputres).astype("int") - self.line_b = (self.line_b / self.scaling_coords2outputres).astype("int") - # Get mask which is closest to 2k image. 
This is an empirical trade-off # between feasible image processing with opencv and mask resolution best_mask_output_dims = 2000 all_mask_dims = [ self.raw_mask.getLevelDimensions(i) for i in range(self.raw_mask.getNumberOfLevels()) ] - mask_ds_level = np.argmin([(i[0] - best_mask_output_dims) ** 2 for i in all_mask_dims]) + self.mask_ds_level = np.argmin([(i[0] - best_mask_output_dims) ** 2 for i in all_mask_dims]) - original_mask = self.raw_mask.getUCharPatch( + self.tissueseg_mask = self.raw_mask.getUCharPatch( startX=0, startY=0, - width=int(all_mask_dims[mask_ds_level][0]), - height=int(all_mask_dims[mask_ds_level][1]), - level=int(mask_ds_level), + width=int(all_mask_dims[self.mask_ds_level][0]), + height=int(all_mask_dims[self.mask_ds_level][1]), + level=int(self.mask_ds_level), ) # Convert mask for opencv processing - original_mask = original_mask / np.max(original_mask) - original_mask = (original_mask * 255).astype("uint8") - self.scaling_mask2outputres = self.outputres_image_dims[0] / original_mask.shape[1] + self.tissueseg_mask = self.tissueseg_mask / np.max(self.tissueseg_mask) + self.tissueseg_mask = (self.tissueseg_mask * 255).astype("uint8") + self.scaling_mask2outputres = self.outputres_image_dims[0] / self.tissueseg_mask.shape[1] # Get information on all connected components in the mask - num_labels, original_mask, stats, _ = cv2.connectedComponentsWithStats( - original_mask, connectivity=8 + num_labels, self.tissueseg_mask, stats, _ = cv2.connectedComponentsWithStats( + self.tissueseg_mask, connectivity=8 ) # Extract largest connected component largest_cc_label = np.argmax(stats[1:, -1]) + 1 - original_mask = ((original_mask == largest_cc_label) * 255).astype("uint8") + self.tissueseg_mask = ((self.tissueseg_mask == largest_cc_label) * 255).astype("uint8") # Some morphological operations for cleaning up edges kernel = cv2.getStructuringElement(shape=cv2.MORPH_ELLIPSE, ksize=(20, 20)) - original_mask = cv2.morphologyEx( - src=original_mask, op=cv2.MORPH_CLOSE, kernel=kernel, iterations=2 + self.tissueseg_mask = cv2.morphologyEx( + src=self.tissueseg_mask, op=cv2.MORPH_CLOSE, kernel=kernel, iterations=2 ) # Slightly enlarge temporary - temp_pad = int(0.05 * original_mask.shape[0]) - original_mask = np.pad( - original_mask, + temp_pad = int(0.05 * self.tissueseg_mask.shape[0]) + self.tissueseg_mask = np.pad( + self.tissueseg_mask, [[temp_pad, temp_pad], [temp_pad, temp_pad]], mode="constant", constant_values=0, @@ -201,14 +198,132 @@ def process_mask(self): # Flood fill to remove holes inside mask seedpoint = (0, 0) - floodfill_mask = np.zeros((original_mask.shape[0] + 2, original_mask.shape[1] + 2)).astype( + floodfill_mask = np.zeros((self.tissueseg_mask.shape[0] + 2, self.tissueseg_mask.shape[1] + + 2)).astype( "uint8" ) - _, _, original_mask, _ = cv2.floodFill(original_mask, floodfill_mask, seedpoint, 255) - original_mask = ( - 1 - original_mask[temp_pad + 1 : -(temp_pad + 1), temp_pad + 1 : -(temp_pad + 1)] + _, _, self.tissueseg_mask, _ = cv2.floodFill(self.tissueseg_mask, floodfill_mask, + seedpoint, + 255) + self.tissueseg_mask = self.tissueseg_mask[ + temp_pad + 1: -(temp_pad + 1), temp_pad + 1: -(temp_pad + 1) + ] + self.tissueseg_mask = 1 - self.tissueseg_mask + + return + + def get_otsu_mask(self): + """ + Get mask based on Otsu thresholding. 
+ """ + + # Get image of same size as tissue segmentation mask + + ### EXPERIMENTAL --- get the image level based on the mask level ### + # New code + im_vs_mask_ds = self.raw_image.getLevelDimensions(0)[0]/self.raw_mask.getLevelDimensions( + 0)[0] + image_ds_level = int(self.mask_ds_level + + int(math.log2(im_vs_mask_ds))) + + # Previous code + # image_ds_level = int(self.mask_ds_level + int(math.log2(self.scaling_coords2outputres))) + ### \\\ EXPERIMENTAL ### + + image_ds_dims = self.raw_image.getLevelDimensions(image_ds_level) + self.otsu_image = self.raw_image.getUCharPatch( + startX=0, + startY=0, + width=int(image_ds_dims[0]), + height=int(image_ds_dims[1]), + level=int(image_ds_level), ) + ### EXPERIMENTAL - shift black background to white for correct otsu thresholding### + self.otsu_image[np.all(self.otsu_image<10, axis=2)] = 255 + ### \\\ EXPERIMENTAL ### + + image_hsv = cv2.cvtColor(self.otsu_image, cv2.COLOR_RGB2HSV) + image_hsv = cv2.medianBlur(image_hsv[:, :, 1], 7) + _, self.otsu_mask = cv2.threshold(image_hsv, 0, 255, cv2.THRESH_OTSU + + cv2.THRESH_BINARY) + self.otsu_mask = (self.otsu_mask / np.max(self.otsu_mask)).astype("uint8") + + # Postprocess the mask a bit + kernel = cv2.getStructuringElement(shape=cv2.MORPH_ELLIPSE, ksize=(15, 15)) + self.otsu_mask = cv2.morphologyEx( + src=self.otsu_mask, op=cv2.MORPH_CLOSE, kernel=kernel, iterations=3 + ) + + return + + + def combine_masks(self): + """ + Combine Otsu mask and tissue segmentation mask, similar in preprocessing scripts. + """ + + # First process the coordinates + self.line_a = (self.line_a / self.scaling_coords2outputres).astype("int") + self.line_b = (self.line_b / self.scaling_coords2outputres).astype("int") + + # Combine masks + self.final_mask = self.otsu_mask * self.tissueseg_mask + + # Postprocess similar to tissue segmentation mask. Get largest cc and floodfill. 
+ num_labels, labeled_im, stats, _ = cv2.connectedComponentsWithStats( + self.final_mask, connectivity=8 + ) + assert num_labels > 1, "mask is empty" + largest_cc_label = np.argmax(stats[1:, -1]) + 1 + self.final_mask = ((labeled_im == largest_cc_label) * 255).astype("uint8") + + # Flood fill + offset = 5 + self.final_mask = np.pad( + self.final_mask, + [[offset, offset], [offset, offset]], + mode="constant", + constant_values=0 + ) + + seedpoint = (0, 0) + floodfill_mask = np.zeros( + (self.final_mask.shape[0] + 2, self.final_mask.shape[1] + 2)).astype("uint8") + _, _, self.final_mask, _ = cv2.floodFill(self.final_mask, floodfill_mask, seedpoint, 255) + self.final_mask = self.final_mask[1 + offset:-1 - offset, 1 + offset:-1 - offset] + self.final_mask = 1 - self.final_mask + + # Crop to nonzero pixels for efficient saving + self.r_idx, self.c_idx = np.nonzero(self.final_mask) + self.final_mask = self.final_mask[np.min(self.r_idx): np.max(self.r_idx), + np.min(self.c_idx): np.max(self.c_idx)] + + # Convert to pyvips array + height, width = self.final_mask.shape + bands = 1 + dformat = "uchar" + self.outputres_mask = pyvips.Image.new_from_memory( + self.final_mask.ravel(), width, height, bands, dformat + ) + + self.outputres_mask = self.outputres_mask.resize(self.scaling_mask2outputres) + + if self.rot_k in range(1, 4): + self.outputres_mask = self.outputres_mask.rotate(self.rot_k * 90) + + # Pad image with zeros + self.outputres_mask = self.outputres_mask.gravity( + "centre", self.target_dims[1], self.target_dims[0] + ) + + return + + def process_mask(self): + """ + Process the mask to a full resolution version + """ + # Get nonzero indices and crop self.r_idx, self.c_idx = np.nonzero(original_mask) original_mask = original_mask[ @@ -381,7 +496,10 @@ def generate_full_res(parameters, log): # initial stitch f.load_images() f.get_scaling() - f.process_mask() + # f.process_mask() + f.get_tissue_seg_mask() + f.get_otsu_mask() + f.combine_masks() f.process_image() ### DEBUGGING ### @@ -420,10 +538,13 @@ def generate_full_res(parameters, log): # Perform blending in areas of overlap log.log(parameters["my_level"], "Blending areas of overlap") - result_image, comp_time = perform_blending( + start = time.time() + result_image = perform_blending( result_image, result_mask, full_res_fragments, log, parameters ) - log.log(parameters["my_level"], f" > finished in {comp_time} mins!\n") + log.log( + parameters["my_level"], f" > finished in {int(np.ceil((time.time()-start)/60))} mins!\n" + ) # Remove temporary mask parameters["sol_save_dir"].joinpath("highres", "temp_mask.tif").unlink() diff --git a/src/pythostitcher_utils/gradient_blending.py b/src/pythostitcher_utils/gradient_blending.py index b99baca..ef9b6a8 100644 --- a/src/pythostitcher_utils/gradient_blending.py +++ b/src/pythostitcher_utils/gradient_blending.py @@ -53,8 +53,6 @@ def perform_blending(result_image, result_mask, full_res_fragments, log, paramet # Param for saving blending result n_valid = 0 - start = time.time() - # Find blending points per contour for c, mask_cnt in enumerate(mask_cnts): @@ -256,9 +254,6 @@ def perform_blending(result_image, result_mask, full_res_fragments, log, paramet else: continue - - comp_time = int(np.ceil((time.time() - start) / 60)) - # Get the correct orientation of the prostate result_image = correct_orientation( mask = mask_ds, @@ -267,7 +262,7 @@ def perform_blending(result_image, result_mask, full_res_fragments, log, paramet debug_visualization = True ) - return result_image, comp_time + return 
result_image def correct_orientation(mask, result_image, parameters, debug_visualization):
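For context on the pyvips hand-off in combine_masks of full_resolution.py: the cropped numpy mask is wrapped as a pyvips image, scaled to the output resolution, and zero-padded to the target canvas. A self-contained sketch with hypothetical sizes (note that pyvips takes width before height, the transpose of numpy's shape order):

    import numpy as np
    import pyvips

    final_mask = (np.random.rand(500, 400) > 0.5).astype("uint8")  # hypothetical mask

    height, width = final_mask.shape
    vips_mask = pyvips.Image.new_from_memory(final_mask.ravel(), width, height, 1, "uchar")

    vips_mask = vips_mask.resize(4.0)                    # upscale towards output resolution
    vips_mask = vips_mask.gravity("centre", 2400, 3000)  # zero-pad around the centre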