diff --git a/src/assembly_utils/jigsawnet.py b/src/assembly_utils/jigsawnet.py index 091682e..2f8a37b 100644 --- a/src/assembly_utils/jigsawnet.py +++ b/src/assembly_utils/jigsawnet.py @@ -214,7 +214,7 @@ def SingleTest(checkpoint_root, K, net, is_training=False): sess_init_op = tf.group(tf.compat.v1.global_variables_initializer(), tf.compat.v1.local_variables_initializer()) sess.run(sess_init_op) - saver.restore(sess, tf.train.latest_checkpoint(f"./{check_point}")) + saver.restore(sess, tf.train.latest_checkpoint(check_point)) sessions.append(sess) # Inference on JigsawNet. Note how this is only performed with batch_size=1. Perhaps diff --git a/src/main.py b/src/main.py index a24919f..72f9f41 100644 --- a/src/main.py +++ b/src/main.py @@ -1,37 +1,43 @@ -import os -os.environ["VIPS_CONCURRENCY"] = "20" import argparse -import logging import json +import logging +import os import pathlib -from preprocessing_utils.prepare_data import prepare_data from assembly_utils.detect_configuration import detect_configuration -from pythostitcher_utils.preprocess import preprocess -from pythostitcher_utils.optimize_stitch import optimize_stitch +from preprocessing_utils.prepare_data import prepare_data from pythostitcher_utils.fragment_class import Fragment -from pythostitcher_utils.get_resname import get_resname from pythostitcher_utils.full_resolution import generate_full_res +from pythostitcher_utils.get_resname import get_resname +from pythostitcher_utils.optimize_stitch import optimize_stitch +from pythostitcher_utils.preprocess import preprocess + +os.environ["VIPS_CONCURRENCY"] = "20" -def load_parameter_configuration(data_dir, save_dir, patient_idx): +def load_parameter_configuration(data_dir, save_dir, output_res): """ Convenience function to load all the PythoStitcher parameters and pack them up in a dictionary for later use. """ # Verify its existence - config_file = pathlib.Path("./config/parameter_config.json") + config_file = pathlib.Path().absolute().parent.joinpath("config/parameter_config.json") assert config_file.exists(), "parameter config file not found" # Load main parameter config with open(config_file) as f: parameters = json.load(f) + # Convert model weight paths to absolute paths + parameters["weights_fragment_classifier"] = pathlib.Path().absolute().parent.joinpath(parameters["weights_fragment_classifier"]) + parameters["weights_jigsawnet"] = pathlib.Path().absolute().parent.joinpath(parameters["weights_jigsawnet"]) + # Insert parsed arguments parameters["data_dir"] = data_dir parameters["save_dir"] = save_dir - parameters["patient_idx"] = patient_idx + parameters["patient_idx"] = data_dir.name + parameters["output_res"] = output_res parameters["fragment_names"] = [i.name for i in data_dir.joinpath("raw_images").iterdir() if i.is_dir()] parameters["n_fragments"] = len(parameters["fragment_names"]) parameters["resolution_scaling"] = [i/parameters["resolutions"][0] for i in parameters["resolutions"]] @@ -69,32 +75,43 @@ def collect_arguments(): description="Stitch prostate histopathology images into a pseudo whole-mount image" ) parser.add_argument( - "--datadir", required=True, help="General data directory with all patients" + "--datadir", + required=True, + type=pathlib.Path, + help="Path to the case to stitch" ) parser.add_argument( - "--savedir", required=True, help="Directory to save the results" + "--savedir", + required=True, + type=pathlib.Path, + help="Directory to save the results" ) parser.add_argument( - "--patient", required=True, help="Patient to process" + "--resolution", + required=True, + default=0.25, + type=float, + help="Output resolution (µm/pixel) of the reconstructed image. Should be roughly " + "in range of 0.25-20." ) args = parser.parse_args() # Extract arguments - patient_idx = args.patient - data_dir = pathlib.Path(args.datadir).joinpath(patient_idx) - save_dir = pathlib.Path(args.savedir).joinpath(patient_idx) + data_dir = pathlib.Path(args.datadir) + save_dir = pathlib.Path(args.savedir).joinpath(data_dir.name) + resolution = args.resolution - assert pathlib.Path(args.datadir).is_dir(), "provided data directory doesn't exist" - assert data_dir.is_dir(), "provided patient could not be found in data directory" + assert data_dir.is_dir(), "provided patient directory doesn't exist" assert data_dir.joinpath("raw_images").is_dir(), "patient has no 'raw_images' directory" assert len(list(data_dir.joinpath("raw_images").iterdir())) > 0, "no images found in 'raw_images' directory" + assert resolution > 0, "output resolution cannot be negative" print(f"\nRunning job with following parameters:" - f"\n - Data dir: {args.datadir}" + f"\n - Data dir: {data_dir}" f"\n - Save dir: {save_dir}" - f"\n - Patient: {patient_idx}") + f"\n - Output resolution: {resolution} µm/pixel\n") - return data_dir, save_dir, patient_idx + return data_dir, save_dir, resolution def main(): @@ -115,7 +132,7 @@ def main(): /data /{Patient_identifier} /raw_images - {fragment_name}.mrxs + {fragment_name}.mrxs§ {fragment_name}.mrxs /raw_masks {fragment_name}.tif @@ -126,8 +143,8 @@ def main(): ### ARGUMENT CONFIGURATION ### # Collect arguments - data_dir, save_dir, patient_idx = collect_arguments() - parameters = load_parameter_configuration(data_dir, save_dir, patient_idx) + data_dir, save_dir, output_res = collect_arguments() + parameters = load_parameter_configuration(data_dir, save_dir, output_res) # Initiate logging file logfile = save_dir.joinpath("pythostitcher_log.log") @@ -144,9 +161,10 @@ def main(): log = logging.getLogger("pythostitcher") parameters["log"] = log - parameters["log"].log(parameters["my_level"], f"Running job with following parameters:" - f"\n - Data dir: {parameters['data_dir']}" - f"\n - Save dir: {parameters['save_dir']}\n") + parameters["log"].log(parameters["my_level"], f"Running job with following parameters:") + parameters["log"].log(parameters["my_level"], f" - Data dir: {parameters['data_dir']}") + parameters["log"].log(parameters["my_level"], f" - Save dir: {parameters['save_dir']}") + parameters["log"].log(parameters["my_level"], f" - Output resolution: {parameters['output_res']}\n") if not data_dir.joinpath("raw_masks").is_dir(): parameters["log"].log( @@ -155,7 +173,8 @@ def main(): f"PythoStitcher with pregenerated tissuemasks, please put these files in " f"[{data_dir.joinpath('raw_masks')}]. If no tissuemasks are supplied, " f"PythoStitcher will use a generic tissue segmentation which may not perform " - f"as well as the AI-generated masks.") + f"as well as the AI-generated masks. In addition, PythoStitcher will not " + f"be able to generate the full resolution end result.") ### MAIN PYTHOSTITCHER #s## # Preprocess data @@ -182,7 +201,7 @@ def main(): fragments = [] for im_path, fragment_name in sol.items(): fragments.append(Fragment( - im_path = im_path, + im_path=im_path, fragment_name=fragment_name, kwargs=parameters) ) @@ -198,7 +217,7 @@ def main(): parameters["log"].log( parameters["my_level"], - f"Succesfully stitched solution {count_sol}" + f"### Succesfully stitched solution {count_sol} ###\n" ) return diff --git a/src/pythostitcher_utils/full_resolution.py b/src/pythostitcher_utils/full_resolution.py index 64695e9..4dc6ac7 100644 --- a/src/pythostitcher_utils/full_resolution.py +++ b/src/pythostitcher_utils/full_resolution.py @@ -1,6 +1,5 @@ import numpy as np import os -os.environ["VIPS_CONCURRENCY"] = "20" import pyvips import cv2 import multiresolutionimageinterface as mir @@ -10,6 +9,9 @@ from .get_resname import get_resname from .fuse_images_highres import fuse_images_highres +from .gradient_blending import perform_blending + +os.environ["VIPS_CONCURRENCY"] = "20" class FullResImage: @@ -22,6 +24,7 @@ class FullResImage: def __init__(self, parameters, idx): self.idx = idx + self.output_res = parameters["output_res"] self.raw_image_name = parameters["raw_image_names"][self.idx] self.raw_mask_name = parameters["raw_mask_names"][self.idx] self.save_dir = parameters["sol_save_dir"] @@ -51,18 +54,11 @@ def load_images(self): resolution. """ - # Full resolution image - self.raw_image = pyvips.Image.new_from_file(str(self.raw_image_path)) - self.raw_image_dims = ( - self.raw_image.get("width"), - self.raw_image.get("height"), - ) - - # Mask obtained with tissue segmentation - self.raw_mask = pyvips.Image.new_from_file(str(self.raw_mask_path)) - self.raw_mask_dims = (self.raw_mask.get("width"), self.raw_mask.get("height")) + # Get full res image, mask and pythostitcher mask + self.opener = mir.MultiResolutionImageReader() + self.raw_image = self.opener.open(str(self.raw_image_path)) + self.raw_mask = self.opener.open(str(self.raw_mask_path)) - # Mask obtained through the preprocessing in pythostitcher self.ps_mask = cv2.imread(str(self.ps_mask_path)) self.ps_mask = cv2.cvtColor(self.ps_mask, cv2.COLOR_BGR2GRAY) @@ -70,26 +66,52 @@ def load_images(self): def get_scaling(self): """ - Get scaling factors to go from PythoStitcher resolution to full resolution + Get scaling factors to go from PythoStitcher resolution to desired output resolution """ - # Get scaling factor between PythoStitcher mask, raw mask and raw image - self.scaling_mask2fullres = int(self.raw_image_dims[0] / self.raw_mask_dims[0]) - try: - self.scaling_ps2fullres = int( - int(self.raw_image.get(f"openslide.level[{self.ps_level}].downsample")) - / self.last_res + # Get full resolution dims + self.raw_image_dims = self.raw_image.getLevelDimensions(0) + self.raw_mask_dims = self.raw_mask.getLevelDimensions(0) + + # Obtain the resolution (µm/pixel) for each level + n_levels = self.raw_image.getNumberOfLevels() + ds_per_level = [self.raw_image.getLevelDownsample(i) for i in range(n_levels)] + res_per_level = [self.raw_image.getSpacing()[0]*scale for i, scale in zip(range(n_levels), ds_per_level)] + + # Get the optimal level based on the desired output resolution + self.output_level = np.argmin([(i - self.output_res)**2 for i in res_per_level]) + + assert self.output_level <= self.ps_level, \ + f"Resolution level of the output image must be lower than the PythoStitcher " \ + f"resolution level. Provided utput level is {self.output_level}, while " \ + f"PythoStitcher level is {self.ps_level}. Please increase the output resolution." + + # Get image on this optimal output level + if self.raw_image_path.suffix == ".mrxs": + self.outputres_image = pyvips.Image.new_from_file( + str(self.raw_image_path), + level=self.output_level ) - except: - opener = mir.MultiResolutionImageReader() - _raw_image = opener.open(str(self.raw_image_path)) - self.scaling_ps2fullres = ( - int(_raw_image.getLevelDownsample(self.ps_level)) / self.last_res + elif self.raw_image_path.suffix == ".tif": + self.outputres_image = pyvips.Image.new_from_file( + str(self.raw_image_path), + page=self.output_level ) + else: + raise ValueError("currently we only support mrxs and tif files") + + # Get new image dims + self.outputres_image_dims = ( + self.outputres_image.get("width"), + self.outputres_image.get("height"), + ) + + # Get scaling factor raw mask and final output resolution + self.scaling_ps2outputres = 2**(self.ps_level - self.output_level) / self.last_res - # Dimension of final stitchting result + # Dimension of final stitching result self.target_dims = [ - int(i * self.scaling_ps2fullres) for i in self.ps_mask.shape + int(i * self.scaling_ps2outputres) for i in self.ps_mask.shape ] # Get the optimal transformation obtained with the genetic algorithm @@ -100,14 +122,14 @@ def get_scaling(self): # Upsample it to use it for the final image self.highres_tform = [ - int(self.ps_tform[self.orientation][0] * self.scaling_ps2fullres), - int(self.ps_tform[self.orientation][1] * self.scaling_ps2fullres), + int(self.ps_tform[self.orientation][0] * self.scaling_ps2outputres), + int(self.ps_tform[self.orientation][1] * self.scaling_ps2outputres), np.round(self.ps_tform[self.orientation][2], 1), tuple( - [int(i * self.scaling_ps2fullres) for i in self.ps_tform[self.orientation][3]] + [int(i * self.scaling_ps2outputres) for i in self.ps_tform[self.orientation][3]] ), tuple( - [int(i * self.scaling_ps2fullres) for i in self.ps_tform[self.orientation][4]] + [int(i * self.scaling_ps2outputres) for i in self.ps_tform[self.orientation][4]] ), ] @@ -118,32 +140,36 @@ def process_mask(self): Process the mask to a full resolution version """ - # Get high resolution mask (spacing 3.88x3.88) - opener = mir.MultiResolutionImageReader() - mask = opener.open(str(self.raw_mask_path)) - original_mask = mask.getUCharPatch( + # Get mask which is closest to 4k image. This is an empirical trade-off + # between feasible image processing with opencv and mask resolution + best_mask_output_dims = 4000 + all_mask_dims = [self.raw_mask.getLevelDimensions(i) for i in range(self.raw_mask.getNumberOfLevels())] + mask_ds_level = np.argmin([(i[0]-best_mask_output_dims)**2 for i in all_mask_dims]) + + original_mask = self.raw_mask.getUCharPatch( startX=0, startY=0, - width=self.raw_mask_dims[0], - height=self.raw_mask_dims[1], - level=0, + width=int(all_mask_dims[mask_ds_level][0]), + height=int(all_mask_dims[mask_ds_level][1]), + level=int(mask_ds_level), ) # Convert mask for opencv processing original_mask = original_mask / np.max(original_mask) original_mask = (original_mask * 255).astype("uint8") + self.scaling_mask2outputres = self.outputres_image_dims[0] / original_mask.shape[1] # Get information on all connected components in the mask - num_labels, labeled_mask, stats, _ = cv2.connectedComponentsWithStats( + num_labels, original_mask, stats, _ = cv2.connectedComponentsWithStats( original_mask, connectivity=4 ) # Extract largest connected component largest_cc_label = np.argmax(stats[1:, -1]) + 1 - original_mask = ((labeled_mask == largest_cc_label) * 255).astype("uint8") + original_mask = ((original_mask == largest_cc_label) * 255).astype("uint8") # Some morphological operations for cleaning up edges - kernel = cv2.getStructuringElement(shape=cv2.MORPH_ELLIPSE, ksize=(40, 40)) + kernel = cv2.getStructuringElement(shape=cv2.MORPH_ELLIPSE, ksize=(20, 20)) original_mask = cv2.morphologyEx( src=original_mask, op=cv2.MORPH_CLOSE, kernel=kernel, iterations=2 ) @@ -180,16 +206,16 @@ def process_mask(self): height, width = original_mask.shape bands = 1 dformat = "uchar" - self.fullres_mask = pyvips.Image.new_from_memory( + self.outputres_mask = pyvips.Image.new_from_memory( original_mask.ravel(), width, height, bands, dformat ) - self.fullres_mask = self.fullres_mask.resize(self.scaling_mask2fullres) + self.outputres_mask = self.outputres_mask.resize(self.scaling_mask2outputres) - self.fullres_mask = self.fullres_mask.rotate(-self.rot_k*90) + self.outputres_mask = self.outputres_mask.rotate(-self.rot_k*90) # Pad image with zeros - self.fullres_mask = self.fullres_mask.gravity( + self.outputres_mask = self.outputres_mask.gravity( "centre", self.target_dims[1], self.target_dims[0] ) @@ -201,29 +227,29 @@ def process_image(self): """ # Dispose of alpha channel if applicable - if self.raw_image.hasalpha(): - self.raw_image = self.raw_image.flatten() + if self.outputres_image.hasalpha(): + self.outputres_image = self.outputres_image.flatten() # Get cropping indices rmin, rmax = ( - int(self.scaling_mask2fullres * np.min(self.r_idx)), - int(self.scaling_mask2fullres * np.max(self.r_idx)), + int(self.scaling_mask2outputres * np.min(self.r_idx)), + int(self.scaling_mask2outputres * np.max(self.r_idx)), ) cmin, cmax = ( - int(self.scaling_mask2fullres * np.min(self.c_idx)), - int(self.scaling_mask2fullres * np.max(self.c_idx)), + int(self.scaling_mask2outputres * np.min(self.c_idx)), + int(self.scaling_mask2outputres * np.max(self.c_idx)), ) width = cmax - cmin height = rmax - rmin # Crop image - self.fullres_image = self.raw_image.crop(cmin, rmin, width, height) + self.outputres_image = self.outputres_image.crop(cmin, rmin, width, height) # Rotate image - self.fullres_image = self.fullres_image.rotate(-self.rot_k*90) + self.outputres_image = self.outputres_image.rotate(-self.rot_k*90) # Pad image with zeros - self.fullres_image = self.fullres_image.gravity( + self.outputres_image = self.outputres_image.gravity( "centre", self.target_dims[1], self.target_dims[0] ) @@ -235,7 +261,7 @@ def process_image(self): rotmat[1, 2] += self.highres_tform[1] # Apply affine transformation - self.fullres_image_rot = self.fullres_image.affine( + self.outputres_image = self.outputres_image.affine( (rotmat[0, 0], rotmat[0, 1], rotmat[1, 0], rotmat[1, 1]), interpolate=pyvips.Interpolate.new("nearest"), odx=rotmat[0, 2], @@ -243,7 +269,7 @@ def process_image(self): oarea=[0, 0, self.highres_tform[4][1], self.highres_tform[4][0]], ) - self.fullres_mask_rot = self.fullres_mask.affine( + self.outputres_mask = self.outputres_mask.affine( (rotmat[0, 0], rotmat[0, 1], rotmat[1, 0], rotmat[1, 1]), interpolate=pyvips.Interpolate.new("nearest"), odx=rotmat[0, 2], @@ -251,45 +277,21 @@ def process_image(self): oarea=[0, 0, self.highres_tform[4][1], self.highres_tform[4][0]], ) - if not self.fullres_image_rot.format == "uchar": - self.fullres_image_rot = self.fullres_image_rot.cast("uchar", shift=False) + if not self.outputres_image.format == "uchar": + self.outputres_image = self.outputres_image.cast("uchar", shift=False) - if not self.fullres_mask_rot.format == "uchar": - self.fullres_mask_rot = self.fullres_mask_rot.cast("uchar", shift=False) + if not self.outputres_mask.format == "uchar": + self.outputres_mask = self.outputres_mask.cast("uchar", shift=False) # Apply mask to images - self.final_image = self.fullres_image_rot.multiply(self.fullres_mask_rot) + self.final_image = self.outputres_image.multiply(self.outputres_mask) if not self.final_image.format == "uchar": self.final_image = self.final_image.cast("uchar", shift=False) return -def mask_eval_handler(result_mask, progress): - """ - Function to display progress of pyvips operation. - - Inputs - - Pyvips instance - - Progress instance - - Outputs - - Progress prompt in log - """ - - global savelog_mask - - percent = int(np.round(progress.percent)) - if percent % 10 == 0 and percent not in savelog_mask: - handler_log.log(35, f" - progress {progress.percent}%") - savelog_mask.append(percent) - - time.sleep(1) - - return - - -def image_eval_handler(result, progress): +def image_eval_handler(result_image, progress): """ Function to display progress of pyvips operation. @@ -313,172 +315,6 @@ def image_eval_handler(result, progress): return -def perform_blending(result_image, result_mask, full_res_fragments, log, blend_dir): - """ - Function to blend areas of overlap. - - Inputs - - Full resolution image - - Full resolution mask - - All fragments - - Logging instance - - Directory to save blending results - - Output - - Full resolution blended image - """ - - # Get dimensions of image - width = result_mask.width - height = result_mask.height - - # Specify tile sizes - tilesize = 8192 - num_xtiles = int(np.ceil(width / tilesize)) - num_ytiles = int(np.ceil(height / tilesize)) - - start = time.time() - - # Loop over columns - for x in range(num_xtiles): - - log.log(35, f" - blending column {x+1}/{num_xtiles}") - - # Loop over rows - for y in range(num_ytiles): - - # To extract tiles via pyvips we need the starting X and Y and the tilesize. - # This tilesize will differ based on whether we can retrieve a full square - # tile, which is only not possible near the right and bottom edge. The tile - # in these cases will be smaller. - - # Case for lower right corner - if x == num_xtiles - 1 and y == num_ytiles - 1: - new_tilesize = [width - x * tilesize, height - y * tilesize] - # Case for right edge - elif x == num_xtiles - 1 and y < num_ytiles - 1: - new_tilesize = [width - x * tilesize, tilesize] - # Case for bottom edge - elif x < num_xtiles - 1 and y == num_ytiles - 1: - new_tilesize = [tilesize, height - y * tilesize] - # Regular cases - else: - new_tilesize = [tilesize, tilesize] - - # Only perform bending in case of overlap - tile_mask = result_mask.crop( - x * tilesize, y * tilesize, new_tilesize[0], new_tilesize[1] - ) - - if tile_mask.max() > 1: - - # Extract the corresponding image and mask for all fragments - images = dict() - masks = dict() - for f in full_res_fragments: - image_patch = f.final_image.crop( - x * tilesize, y * tilesize, new_tilesize[0], new_tilesize[1], - ) - image = np.ndarray( - buffer=image_patch.write_to_memory(), - dtype=np.uint8, - shape=[new_tilesize[1], new_tilesize[0], image_patch.bands], - ) - - mask_patch = f.fullres_mask_rot.crop( - x * tilesize, y * tilesize, new_tilesize[0], new_tilesize[1], - ) - mask = np.ndarray( - buffer=mask_patch.write_to_memory(), - dtype=np.uint8, - shape=[new_tilesize[1], new_tilesize[0]], - ) - - images[f.orientation] = image - masks[f.orientation] = mask - - # Perform the actual blending - blend, grad, overlap_fragments, valid = fuse_images_highres( - images, masks - ) - - if valid: - - # Get overlap contours for plotting - overlap = (~np.isnan(grad) * 255).astype("uint8") - overlap_cnts, _ = cv2.findContours( - overlap, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE, - ) - - # Show and save blended result - plt.figure(figsize=(12, 10)) - plt.suptitle(f"Result for row {y} and col {x}", fontsize=24) - plt.subplot(231) - plt.title(f"Mask fragment '{overlap_fragments[0]}'", fontsize=20) - plt.imshow(masks[overlap_fragments[0]], cmap="gray") - plt.axis("off") - plt.clim([0, 1]) - plt.subplot(232) - plt.title(f"Mask fragment '{overlap_fragments[1]}'", fontsize=20) - plt.imshow(masks[overlap_fragments[1]], cmap="gray") - plt.axis("off") - plt.clim([0, 1]) - plt.subplot(233) - plt.title("Mask overlap + gradient", fontsize=20) - plt.imshow( - (masks[overlap_fragments[0]] + masks[overlap_fragments[1]])==2, - cmap="gray", - ) - plt.imshow(grad, cmap="jet", alpha=0.5) - plt.axis("off") - plt.colorbar(fraction=0.046, pad=0.04) - plt.subplot(234) - plt.title(f"Image fragment '{overlap_fragments[0]}'", fontsize=20) - plt.imshow(images[overlap_fragments[0]]) - for cnt in overlap_cnts: - cnt = np.squeeze(cnt) - plt.plot(cnt[:, 0], cnt[:, 1], c="r", linewidth=3) - plt.axis("off") - plt.subplot(235) - plt.title(f"Image fragment '{overlap_fragments[1]}'", fontsize=20) - plt.imshow(images[overlap_fragments[1]]) - for cnt in overlap_cnts: - cnt = np.squeeze(cnt) - plt.plot(cnt[:, 0], cnt[:, 1], c="r", linewidth=3) - plt.axis("off") - plt.subplot(236) - plt.title("Blend image", fontsize=20) - plt.imshow(blend) - for cnt in overlap_cnts: - cnt = np.squeeze(cnt) - plt.plot(cnt[:, 0], cnt[:, 1], c="r", linewidth=3) - plt.axis("off") - plt.tight_layout() - plt.savefig( - f"{blend_dir}/row{str(y).zfill(4)}_col{str(x).zfill(4)}.png" - ) - plt.close() - - # Insert blended image - h, w = blend.shape[:2] - bands = 3 - dformat = "uchar" - blend_image = pyvips.Image.new_from_memory( - blend.ravel(), w, h, bands, dformat - ) - - result_image = result_image.insert( - blend_image, x * new_tilesize[0], y * new_tilesize[1] - ) - - else: - continue - - comp_time = int((time.time() - start) / 60) - - return result_image, comp_time - - def generate_full_res(parameters, log): """ Function to generate the final full resolution stitching results. @@ -495,23 +331,25 @@ def generate_full_res(parameters, log): savelog_image, savelog_mask = [0], [0] handler_log = copy.deepcopy(log) + parameters["blend_dir"] = parameters["sol_save_dir"].joinpath("highres", "blend_summary") + # Initiate class for each fragment to handle full resolution image full_res_fragments = [FullResImage(parameters, idx) for idx in range(parameters["n_fragments"])] - blend_dir = parameters["sol_save_dir"].joinpath("highres", "blend_summary") - - log.log(parameters["my_level"], "Processing full resolution fragments") + log.log(parameters["my_level"], "Processing high resolution fragments") start = time.time() for f in full_res_fragments: - log.log(parameters["my_level"], f" - fragment '{f.raw_image_path.name}'") + log.log(parameters["my_level"], f" - '{f.raw_image_path.name}'") + # Transform each fragment such that all final images can just be added as an + # initial stitch f.load_images() f.get_scaling() f.process_mask() f.process_image() log.log( - parameters["my_level"], f" > finished in {int((time.time()-start)/60)} mins!\n" + parameters["my_level"], f" > finished in {int(np.ceil((time.time()-start)/60))} mins!\n" ) # Add all images as the default stitching method @@ -520,43 +358,34 @@ def generate_full_res(parameters, log): result_image = result_image.cast("uchar", shift=False) # Do the same for masks - result_mask = pyvips.Image.sum([f.fullres_mask_rot for f in full_res_fragments]) + result_mask = pyvips.Image.sum([f.outputres_mask for f in full_res_fragments]) if not result_mask.format == "uchar": result_mask = result_mask.cast("uchar", shift=False) - """ - # Save full resolution mask if needed - log.log(parameters["my_level"], "Saving full resolution mask") - start = time.time() - result_mask.set_progress(True) - result_mask.signal_connect("eval", mask_eval_handler) - result_mask.tiffsave( - f"{parameters['save_dir']}/highres/fullres_mask.tif", - tile=True, - compression="lzw", - bigtiff=True, - pyramid=True, - Q=80, - ) - log.log(parameters["my_level"], f" > finished in {int((time.time()-start)/60)} mins!\n") - """ + # Save temp .tif version of mask for later use in blending + parameters["tif_mask_path"] = str(parameters["sol_save_dir"].joinpath("highres", "temp_mask.tif")) + # result_mask.write_to_file( + # tif_mask_path, + # tile=True, + # compression="lzw", + # bigtiff=True, + # pyramid=True, + # ) # Perform blending in areas of overlap log.log(parameters["my_level"], "Blending areas of overlap") - result, comp_time = perform_blending( - result_image, result_mask, full_res_fragments, log, blend_dir + result_image, comp_time = perform_blending( + result_image, result_mask, full_res_fragments, log, parameters ) log.log(parameters["my_level"], f" > finished in {comp_time} mins!\n") - result = result_image - # Save final result - log.log(parameters["my_level"], "Saving full resolution result") + log.log(parameters["my_level"], f"Saving high resolution result at {parameters['output_res']} µm/pixel") start = time.time() - result.set_progress(True) - result.signal_connect("eval", image_eval_handler) - result.tiffsave( - str(parameters["sol_save_dir"].joinpath("highres", "fullres_image.tif")), + result_image.set_progress(True) + result_image.signal_connect("eval", image_eval_handler) + result_image.write_to_file( + str(parameters["sol_save_dir"].joinpath("highres", f"stitched_image_{parameters['output_res']}_micron.tif")), tile=True, compression="jpeg", bigtiff=True, @@ -564,7 +393,7 @@ def generate_full_res(parameters, log): Q=20, ) log.log( - parameters["my_level"], f" > finished in {int((time.time()-start)/60)} mins!\n" + parameters["my_level"], f" > finished in {int(np.ceil((time.time()-start)/60))} mins!\n" ) return diff --git a/src/pythostitcher_utils/fuse_images_highres.py b/src/pythostitcher_utils/fuse_images_highres.py index 7f74e2a..f4de8cf 100644 --- a/src/pythostitcher_utils/fuse_images_highres.py +++ b/src/pythostitcher_utils/fuse_images_highres.py @@ -129,20 +129,21 @@ def fuse_images_highres(images, masks): names = list(images.keys()) combinations = itertools.combinations(names, 2) + # Possible combination pairs hor_combinations = [ - ["ul", "ur"], - ["ul", "lr"], - ["ll", "ur"], - ["ll", "lr"], - ["left", "right"] - ] + ["ul", "ur"], + ["ul", "lr"], + ["ll", "ur"], + ["ll", "lr"], + ["left", "right"] + ] ver_combinations = [ - ["ul", "ll"], - ["ul", "lr"], - ["ur", "ll"], - ["ur", "lr"], - ["top", "bottom"] - ] + ["ul", "ll"], + ["ul", "lr"], + ["ur", "ll"], + ["ur", "lr"], + ["top", "bottom"] + ] # Create some lists for iterating total_mask = np.sum(list(masks.values()), axis=0).astype("uint8") diff --git a/src/pythostitcher_utils/gradient_blending.py b/src/pythostitcher_utils/gradient_blending.py new file mode 100644 index 0000000..ae46073 --- /dev/null +++ b/src/pythostitcher_utils/gradient_blending.py @@ -0,0 +1,171 @@ +import multiresolutionimageinterface as mir +import numpy as np +import cv2 +import pyvips +import matplotlib.pyplot as plt + + +def perform_blending(result_image, result_mask, full_res_fragments, log, parameters): + """ + Function to blend areas of overlap. + + Inputs + - Full resolution image + - Full resolution mask + - All fragments + - Logging instance + - Directory to save blending results + + Output + - Full resolution blended image + """ + + # Get dimensions of image + width = result_mask.width + height = result_mask.height + + # Specify tile sizes + tilesize = 4000 + num_xtiles = int(np.ceil(width / tilesize)) + num_ytiles = int(np.ceil(height / tilesize)) + + start = time.time() + + # Loop over columns + for x in range(num_xtiles): + + log.log(45, f" - blending column {x+1}/{num_xtiles}") + + # Loop over rows + for y in range(num_ytiles): + + # To extract tiles via pyvips we need the starting X and Y and the tilesize. + # This tilesize will differ based on whether we can retrieve a full square + # tile, which is only not possible near the right and bottom edge. The tile + # in these cases will be smaller. + + # Case for lower right corner + if x == num_xtiles - 1 and y == num_ytiles - 1: + new_tilesize = [width - x * tilesize, height - y * tilesize] + # Case for right edge + elif x == num_xtiles - 1 and y < num_ytiles - 1: + new_tilesize = [width - x * tilesize, tilesize] + # Case for bottom edge + elif x < num_xtiles - 1 and y == num_ytiles - 1: + new_tilesize = [tilesize, height - y * tilesize] + # Regular cases + else: + new_tilesize = [tilesize, tilesize] + + # Only perform bending in case of overlap + tile_mask = result_mask.crop( + x * tilesize, y * tilesize, new_tilesize[0], new_tilesize[1] + ) + + if tile_mask.max() > 1: + + # Extract the corresponding image and mask for all fragments + images = dict() + masks = dict() + for f in full_res_fragments: + image_patch = f.outputres_mask.crop( + x * tilesize, y * tilesize, new_tilesize[0], new_tilesize[1], + ) + image = np.ndarray( + buffer=image_patch.write_to_memory(), + dtype=np.uint8, + shape=[new_tilesize[1], new_tilesize[0], image_patch.bands], + ) + + mask_patch = f.outputres_mask.crop( + x * tilesize, y * tilesize, new_tilesize[0], new_tilesize[1], + ) + mask = np.ndarray( + buffer=mask_patch.write_to_memory(), + dtype=np.uint8, + shape=[new_tilesize[1], new_tilesize[0]], + ) + + images[f.orientation] = image + masks[f.orientation] = mask + + # Perform the actual blending + blend, grad, overlap_fragments, valid = fuse_images_highres( + images, masks + ) + + if valid: + + # Get overlap contours for plotting + overlap = (~np.isnan(grad) * 255).astype("uint8") + overlap_cnts, _ = cv2.findContours( + overlap, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE, + ) + + # Show and save blended result + plt.figure(figsize=(12, 10)) + plt.suptitle(f"Result for row {y} and col {x}", fontsize=24) + plt.subplot(231) + plt.title(f"Mask fragment '{overlap_fragments[0]}'", fontsize=20) + plt.imshow(masks[overlap_fragments[0]], cmap="gray") + plt.axis("off") + plt.clim([0, 1]) + plt.subplot(232) + plt.title(f"Mask fragment '{overlap_fragments[1]}'", fontsize=20) + plt.imshow(masks[overlap_fragments[1]], cmap="gray") + plt.axis("off") + plt.clim([0, 1]) + plt.subplot(233) + plt.title("Mask overlap + gradient", fontsize=20) + plt.imshow( + (masks[overlap_fragments[0]] + masks[overlap_fragments[1]])==2, + cmap="gray", + ) + plt.imshow(grad, cmap="jet", alpha=0.5) + plt.axis("off") + plt.colorbar(fraction=0.046, pad=0.04) + plt.subplot(234) + plt.title(f"Image fragment '{overlap_fragments[0]}'", fontsize=20) + plt.imshow(images[overlap_fragments[0]]) + for cnt in overlap_cnts: + cnt = np.squeeze(cnt) + plt.plot(cnt[:, 0], cnt[:, 1], c="r", linewidth=3) + plt.axis("off") + plt.subplot(235) + plt.title(f"Image fragment '{overlap_fragments[1]}'", fontsize=20) + plt.imshow(images[overlap_fragments[1]]) + for cnt in overlap_cnts: + cnt = np.squeeze(cnt) + plt.plot(cnt[:, 0], cnt[:, 1], c="r", linewidth=3) + plt.axis("off") + plt.subplot(236) + plt.title("Blend image", fontsize=20) + plt.imshow(blend) + for cnt in overlap_cnts: + cnt = np.squeeze(cnt) + plt.plot(cnt[:, 0], cnt[:, 1], c="r", linewidth=3) + plt.axis("off") + plt.tight_layout() + plt.savefig( + f"{parameters['blend_dir']}/row{str(y).zfill(4)}_col{str(x).zfill(4)}.png" + ) + plt.close() + + # Insert blended image + h, w = blend.shape[:2] + bands = 3 + dformat = "uchar" + blend_image = pyvips.Image.new_from_memory( + blend.ravel(), w, h, bands, dformat + ) + + result_image = result_image.insert( + blend_image, x * new_tilesize[0], y * new_tilesize[1] + ) + + else: + continue + + comp_time = int(np.ceil((time.time() - start) / 60)) + + return result_image, comp_time \ No newline at end of file