
refactor instance seg processing for 2d and 3d (#306)
Co-authored-by: Benjamin Morris <[email protected]>
benjijamorris and Benjamin Morris authored Nov 2, 2023
1 parent 7c0a47e commit 94285f0
Showing 1 changed file with 57 additions and 60 deletions.
117 changes: 57 additions & 60 deletions cyto_dl/models/im2im/utils/instance_seg.py
@@ -60,7 +60,7 @@ def __init__(
self.dim = dim
self.allow_missing_keys = allow_missing_keys
self.kernel_size = kernel_size
- self.anisotropy = torch.as_tensor([anisotropy, 1, 1])
+ self.anisotropy = torch.as_tensor([anisotropy if dim == 3 else 1] + [1] * (dim - 1))
self.thin = thin

def shrink(self, im):
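Note: the constructor change above is the heart of the 2d/3d generalization - the anisotropy tensor now has one entry per spatial axis. A minimal sketch (hypothetical helper, illustrative values, not part of this commit) of what the new expression evaluates to:

import torch

def make_anisotropy(anisotropy: float, dim: int) -> torch.Tensor:
    # dim == 3 -> scale the z axis by `anisotropy`; dim == 2 -> no scaling
    return torch.as_tensor([anisotropy if dim == 3 else 1] + [1] * (dim - 1))

assert make_anisotropy(2.5, 3).tolist() == [2.5, 1.0, 1.0]  # ZYX
assert make_anisotropy(2.5, 2).tolist() == [1, 1]           # YX, isotropic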
@@ -75,12 +75,12 @@ def shrink(self, im):

def skeleton_tall(self, img, max_label):
"""Skeletonize 3d image with increased thickness in z."""
- if max_label == 0:
+ if max_label == 0 or self.dim == 2:
return skeletonize(img)
tall_skeleton = np.stack([skeletonize(np.max(img, 0))] * img.shape[0])
return tall_skeleton

- def label_2d(self, img):
+ def label_slice(self, img):
out = np.zeros_like(img, dtype=np.int16)
for z in range(img.shape[0]):
lab = label(img[z])
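Note: skeleton_tall now short-circuits to a plain 2d skeletonization when dim == 2 or when no objects are present. A standalone sketch (shapes illustrative) of the 3d "tall skeleton" trick - skeletonize the z max-projection, then repeat it through every slice:

import numpy as np
from skimage.morphology import skeletonize

img = np.zeros((4, 16, 16), dtype=bool)  # ZYX binary mask
img[:, 4:12, 4:12] = True
flat = skeletonize(np.max(img, 0))       # 2d skeleton of the max-projection
tall = np.stack([flat] * img.shape[0])   # same skeleton in every z slice
assert tall.shape == img.shape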
@@ -93,7 +93,7 @@ def topology_preserving_thinning(self, bw, min_size=100):
Use skeleton to bridge gaps created by erosion.
"""
- selem = ball(self.thin)[:: self.thin]
+ selem = ball(self.thin)[:: int(self.anisotropy[0])] if self.dim == 3 else disk(self.thin)
eroded = erosion(bw, selem)
# only want to preserve connections between significantly-sized objects
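Note: the erosion now uses a dimension-aware structuring element - a ball subsampled along z in 3d, a disk in 2d. Sketch (illustrative values; z_step stands in for int(anisotropy[0])):

from skimage.morphology import ball, disk

thin, z_step = 5, 2
selem_3d = ball(thin)[::z_step]  # shape (6, 11, 11): ball flattened along z
selem_2d = disk(thin)            # shape (11, 11)
print(selem_3d.shape, selem_2d.shape)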

@@ -105,14 +105,13 @@ def topology_preserving_thinning(self, bw, min_size=100):
return eroded

skel = self.skeleton_tall(bw, max_label)

if max_label == 0:
return skel

# if erosion separates object into multiple pieces, use skeleton to bridge those pieces into single object
# 1. isolate pieces of skeleton that are outside of eroded objects (i.e. could bridge between objects)
skel[eroded != 0] = 0
- skel = self.label_2d(skel)
+ skel = self.label_slice(skel) if self.dim == 3 else label(skel)

for i in np.unique(skel)[1:]:
# 3. find number of non-background objects overlapped by piece of skeleton, add back in pieces that overlap multiple obj
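Note: step 3 above keeps a skeleton fragment only if it overlaps more than one eroded object, i.e. it can bridge pieces split apart by erosion. A simplified sketch of that test (hypothetical helper; the actual loop body is partly collapsed in this view):

import numpy as np

def bridges_multiple_objects(fragment_mask: np.ndarray, eroded_labels: np.ndarray) -> bool:
    touched = np.unique(eroded_labels[fragment_mask])
    touched = touched[touched != 0]  # ignore background
    return touched.size > 1          # fragment connects two or more objects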
@@ -134,9 +133,7 @@ def _get_point_embeddings(self, object_points, skeleton_points):

def smooth_embedding(self, embedding):
"""Smooths embedding by convolving with a mean kernel, excluding non-object pixels."""
- kernel = np.ones((self.kernel_size, self.kernel_size, self.kernel_size)) / (
- self.kernel_size**3
- )
+ kernel = np.ones([self.kernel_size] * self.dim) / self.kernel_size**self.dim
nan_embed = embedding.clone()
nan_embed[nan_embed == 0] = torch.nan
for i in range(embedding.shape[0]):
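Note: the mean kernel is now built for any dimensionality instead of being hard-coded to a 3d cube. Sketch (illustrative size):

import numpy as np

kernel_size, dim = 3, 2
kernel = np.ones([kernel_size] * dim) / kernel_size**dim  # 3x3 mean filter
assert kernel.shape == (3, 3) and np.isclose(kernel.sum(), 1.0)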
@@ -148,33 +145,36 @@ def embed_from_skel(self, skel, iseg):
def embed_from_skel(self, skel, iseg):
"""Find per-pixel embedding vector to closest point on skeleton."""
iseg[skel != 0] = 0
- embed = torch.zeros(3, iseg.shape[0], iseg.shape[1], iseg.shape[2])
- skel_boundary = (
- torch.from_numpy(find_boundaries(skel.numpy(), mode="inner")) * skel
- ) # propagate labels

+ # 3ZYX vector field for 3d, 2YX for 2d
+ embed = torch.zeros([self.dim] + [iseg.shape[i] for i in range(self.dim)])

+ # propagate labels to boundaries
+ skel_boundary = torch.from_numpy(find_boundaries(skel.numpy(), mode="inner")) * skel
for i in np.unique(iseg)[1:]:
- object_points = iseg.eq(i).nonzero()
- skel_points = skel_boundary.eq(i).nonzero()
+ object_mask = iseg.eq(i)
+ # distances should take into account z anisotropy
+ object_points = object_mask.nonzero().mul(self.anisotropy)
+ skel_points = skel_boundary.eq(i).nonzero().mul(self.anisotropy)
if skel_points.numel() == 0:
continue
- # distances should take into account z anisotropy
- point_embeddings = self._get_point_embeddings(
- object_points.mul(self.anisotropy), skel_points.mul(self.anisotropy)
- )
- embed[:, object_points.T[0], object_points.T[1], object_points.T[2]] = point_embeddings
+ point_embeddings = self._get_point_embeddings(object_points, skel_points)
+ embed[:, object_mask] = point_embeddings
# smooth sharp transitions from spatial embedding
embed = self.smooth_embedding(embed)

# turn spatial embedding into offset vector by subtracting pixel coordinates
anisotropic_shape = torch.as_tensor(iseg.shape).mul(self.anisotropy)
coordinates = torch.stack(
torch.meshgrid(
- torch.linspace(0, anisotropic_shape[0] - 1, iseg.shape[0]),
- torch.linspace(0, anisotropic_shape[1] - 1, iseg.shape[1]),
- torch.linspace(0, anisotropic_shape[2] - 1, iseg.shape[2]),
+ *[
+ torch.linspace(0, anisotropic_shape[i] - 1, iseg.shape[i])
+ for i in range(self.dim)
+ ]
)
)
- embed[embed != 0] -= coordinates[embed != 0]
+ embed_pts = embed.ne(0)
+ embed[embed_pts] -= coordinates[embed_pts]
return embed
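Note: the coordinate grid is built in anisotropy-scaled "real" space so distances along z match physical distances, and the starred comprehension makes it work for both 2d and 3d. A standalone sketch of the same meshgrid pattern (shapes and anisotropy illustrative; indexing="ij" spelled out here, which matches meshgrid's legacy default):

import torch

shape = (4, 8, 8)  # ZYX
anisotropy = torch.as_tensor([2.5, 1.0, 1.0])
aniso_shape = torch.as_tensor(shape) * anisotropy
coords = torch.stack(
    torch.meshgrid(
        *[torch.linspace(0, float(aniso_shape[i]) - 1, shape[i]) for i in range(3)],
        indexing="ij",
    )
)
print(coords.shape)  # torch.Size([3, 4, 8, 8]): per-axis physical coordinate of each voxel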

def _get_object_contacts(self, img):
Expand All @@ -191,10 +191,12 @@ def _get_cmap(self, skel_edt, im):
"""Create costmap to increase loss in boundary areas."""
points_with_vecs = im.clone().squeeze()
points_with_vecs[skel_edt > 0] = 0
# emphasize very thin areas
add_in_thin = np.logical_and(skel_edt > 0, skel_edt < 3)
# emphasize areas where vector field is nonzero
points_with_vecs = np.logical_or(points_with_vecs, add_in_thin)
- sigma = torch.as_tensor([2, 2, 2]) / self.anisotropy
- sigma = torch.max(sigma, torch.ones(3)).numpy()
+ sigma = torch.as_tensor([2] * self.dim) / self.anisotropy
+ sigma = torch.max(sigma, torch.ones(self.dim)).numpy()
cmap = gaussian(points_with_vecs > 0, sigma=sigma)
# emphasize boundary points
cmap /= cmap.max()
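Note: the Gaussian sigma now shrinks along anisotropic axes but is clamped to at least one pixel per axis. Sketch (illustrative values):

import torch

dim = 3
anisotropy = torch.as_tensor([2.5, 1.0, 1.0])
sigma = torch.as_tensor([2.0] * dim) / anisotropy  # tensor([0.8, 2.0, 2.0])
sigma = torch.max(sigma, torch.ones(dim)).numpy()  # array([1., 2., 2.])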
@@ -274,7 +276,6 @@ def __init__(

def _flip(self, img, is_label):
img = self.flipper(img)

if is_label:
assert (
img.shape[0] == 4 + self.dim
@@ -307,7 +308,6 @@ def __init__(self, dim: int = 3):
dim:int=3
Spatial dimension of input images.
"""

self.dim = dim
self.skeleton_loss = CMAP_loss(torch.nn.MSELoss(reduction="none"))
self.vector_loss = CMAP_loss(torch.nn.MSELoss(reduction="none"))
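Note: CMAP_loss internals are not shown in this diff; conceptually it weights an unreduced per-pixel loss by a costmap like the one from _get_cmap. A rough sketch of that idea (an assumption about the pattern, not the class's actual code):

import torch

mse = torch.nn.MSELoss(reduction="none")
pred, target = torch.rand(1, 3, 8, 8), torch.rand(1, 3, 8, 8)
cmap = torch.rand(1, 1, 8, 8)             # per-pixel cost, e.g. boundary emphasis
loss = (mse(pred, target) * cmap).mean()  # weight errors, then reduce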
@@ -347,37 +347,39 @@ class InstanceSegCluster:

def __init__(
self,
+ dim: int = 3,
anisotropy: float = 2.6,
skel_threshold: float = 0,
semantic_threshold: float = 0,
min_size: int = 1000,
distance_threshold: int = 100,
):
- self.anisotropy = anisotropy
+ self.dim = dim
+ self.anisotropy = torch.as_tensor([anisotropy if dim == 3 else 1] + [1] * (dim - 1))
self.skel_threshold = skel_threshold
self.semantic_threshold = semantic_threshold
self.min_size = min_size
self.distance_threshold = distance_threshold

def _get_point_embeddings(self, object_points, skeleton_points):
"""
object_points: (N, dim) array of embedded points from semantic segmentation
skeleton_points: (N, dim) array of points on skeleton boundary
"""
tree = KDTree(skeleton_points)
dist, idx = tree.query(object_points)
return dist, tree.data[idx].T.astype(int)

- def kd_clustering(self, embed_z, embed_y, embed_x, skel):
+ def kd_clustering(self, embeddings, skel):
"""assign embedded points to closest skeleton."""
- skel = find_boundaries(skel, mode="inner") * skel # propagate labels
+ skel = find_boundaries(skel, mode="inner") * skel # propagate labels to boundaries
skel_points = np.stack(skel.nonzero()).T
- embed_points = torch.stack((embed_z, embed_y, embed_x)).numpy()
+ embed_points = np.stack(embeddings).T
(
dist_to_closest_skel,
closest_skel_point_to_embedding,
- ) = self._get_point_embeddings(embed_points.T, skel_points)
- embedding_labels = skel[
- closest_skel_point_to_embedding[0],
- closest_skel_point_to_embedding[1],
- closest_skel_point_to_embedding[2],
- ]
+ ) = self._get_point_embeddings(embed_points, skel_points)
+ embedding_labels = skel[tuple(closest_skel_point_to_embedding[:3])]
# remove points too far from any skeleton
embedding_labels[dist_to_closest_skel > self.distance_threshold] = 0
return embedding_labels
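Note: kd_clustering now takes a list of per-axis embedding tensors instead of separate embed_z/embed_y/embed_x arguments. A toy standalone sketch of the KDTree assignment it performs (values illustrative):

import numpy as np
from scipy.spatial import KDTree

skel_points = np.array([[0, 0], [10, 10]])       # boundary points of two skeletons
skel_labels = np.array([1, 2])
embedded = np.array([[1, 1], [9, 9], [50, 50]])  # embedded semantic points
dist, idx = KDTree(skel_points).query(embedded)
labels = skel_labels[idx]
labels[dist > 20] = 0                            # distance_threshold analogue
print(labels)                                    # [1 2 0]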
@@ -388,19 +390,19 @@ def _get_largest_cc(self, im):
return im == largest_cc

def __call__(self, image):
- image = image.cpu()
+ image = image.detach().cpu()
skel = image[0].numpy()
semantic = image[1]
- embedding = image[2:5]
+ embedding = image[2 : 2 + self.dim].float()
# z embeddings are anisotropic, have to adjust to coordinates in real space, not pixel space
- anisotropic_shape = torch.as_tensor(semantic.shape).mul(
- torch.as_tensor([self.anisotropy, 1, 1])
- )
+ anisotropic_shape = torch.as_tensor(semantic.shape).mul(self.anisotropy)

coordinates = torch.stack(
torch.meshgrid(
- torch.linspace(0, anisotropic_shape[0] - 1, semantic.shape[0]),
- torch.linspace(0, anisotropic_shape[1] - 1, semantic.shape[1]),
- torch.linspace(0, anisotropic_shape[2] - 1, semantic.shape[2]),
+ *[
+ torch.linspace(0, anisotropic_shape[i] - 1, semantic.shape[i])
+ for i in range(self.dim)
+ ]
)
)
embedding += coordinates
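Note: the channel layout assumed by __call__ - skeleton, semantic, then dim embedding channels - generalizes the old hard-coded image[2:5] slice. Sketch (shapes illustrative):

import torch

dim = 3
pred = torch.rand(2 + dim, 4, 8, 8)    # one ZYX volume: 2 + dim channels
skel, semantic = pred[0], pred[1]      # skeleton and semantic predictions
embedding = pred[2 : 2 + dim].float()  # per-axis offset channels (ZYX order)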
@@ -413,25 +415,20 @@ def __call__(self, image):
skel = remove_small_objects(skel, self.min_size)

semantic = semantic > self.semantic_threshold
+ # if only one skeleton, return largest connected component of semantic segmentation
if len(np.unique(skel)) == 2:
- return self._get_largest_cc(semantic)
+ return self._get_largest_cc(semantic).astype(np.uint8)

out = np.zeros_like(semantic, dtype=np.uint16)
- semantic_points = semantic.nonzero().T

- # find pixel coordinates pointed to by each z, y, x point within semantic segmentation
- embed_z = embedding[0][semantic_points[0], semantic_points[1], semantic_points[2]]
- embed_z /= self.anisotropy
- embed_z = embed_z.clip(0, semantic.shape[0] - 1).round().int()
-
- embed_y = embedding[1][semantic_points[0], semantic_points[1], semantic_points[2]]
- embed_y = embed_y.clip(0, semantic.shape[1] - 1).round().int()
-
- embed_x = embedding[2][semantic_points[0], semantic_points[1], semantic_points[2]]
- embed_x = embed_x.clip(0, semantic.shape[2] - 1).round().int()
embeddings = []
for i in range(embedding.shape[0]):
dim_embed = embedding[i][semantic] / self.anisotropy[i]
dim_embed = dim_embed.clip(0, semantic.shape[i] - 1).round().int()
embeddings.append(dim_embed)

# assign each embedded point the label of the closest skeleton
- labeled_embed = self.kd_clustering(embed_z, embed_y, embed_x, skel)
+ labeled_embed = self.kd_clustering(embeddings, skel)
# propagate embedding label to semantic segmentation
- out[semantic_points[0], semantic_points[1], semantic_points[2]] = labeled_embed
+ out[semantic] = labeled_embed
return out
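Note: the per-axis loop above replaces three hand-written z/y/x blocks: each real-space embedding axis is divided by its anisotropy, clamped into the image bounds, and rounded to pixel indices. Sketch for a single axis (illustrative values):

import torch

z_extent, z_aniso = 4, 2.5
embed_z = torch.tensor([-1.0, 3.0, 12.5])  # real-space z embeddings
pixel_z = (embed_z / z_aniso).clip(0, z_extent - 1).round().int()
print(pixel_z)                             # tensor([0, 1, 3])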
