Skip to content

Commit

Permalink
Bring over more dependencies.
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinjohncutler committed Jun 24, 2022
1 parent 17f75c0 commit cde12f5
Showing 1 changed file with 69 additions and 48 deletions.
117 changes: 69 additions & 48 deletions ncolor/format_labels.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,14 @@
import numpy as np
from scipy.ndimage.morphology import binary_dilation, binary_erosion
import mahotas as mh
import fastremap


try:
from skimage import measure
from skimage.morphology import remove_small_holes
SKIMAGE_ENABLED = True
except:
SKIMAGE_ENABLED = False
# # Should work for 3D too. Could put into usigned integer form at the end...
# # Also could use some parallelization
# def format_labels(labels, clean=False, min_area=9):
# """
# Puts labels into 'standard form', i.e. background=0 and cells 1,2,3,...,N-1,N.
# Optional clean flag: disconnect and disjoint masks and discard small masks beflow min_area.
# min_area default is 9px.
# """

# # Labels are stored as a part of a float array in Cellpose, so it must be cast back here.
# # some people also use -1 as background, so we must cast to the signed integar class. We
# # can safely assume no 2D or 3D image will have more than 2^31 cells. Finally, cv2 does not
# # play well with unsigned integers (saves to default uint8), so we cast to uint32.
# labels = labels.astype('int32')
# labels -= np.min(labels)
# labels = labels.astype('uint32')

# # optional cleanup
# if clean:
# inds = np.unique(labels)
# for j in inds[inds>0]:
# mask = labels==j
# lbl = measure.label(mask)
# regions = measure.regionprops(lbl)
# regions.sort(key=lambda x: x.area, reverse=True)
# if len(regions) > 1:
# print('Warning - found mask with disjoint label.')
# for rg in regions[1:]:
# if rg.area <= min_area:
# labels[rg.coords[:,0], rg.coords[:,1]] = 0
# print('secondary disjoint part smaller than min_area. Removing it.')
# else:
# print('secondary disjoint part bigger than min_area, relabeling. Area:',rg.area,
# 'Label value:',np.unique(labels[rg.coords[:,0], rg.coords[:,1]]))
# labels[rg.coords[:,0], rg.coords[:,1]] = np.max(labels)+1

# rg0 = regions[0]
# if rg0.area <= min_area:
# labels[rg0.coords[:,0], rg0.coords[:,1]] = 0
# print('Warning - found mask area less than', min_area)
# print('Removing it.')

# if np.any(labels):
# fastremap.renumber(labels,in_place=True) # convenient to have unit increments from 1 to N cells
# labels = fastremap.refit(labels) # put into smaller data type if possible
# return labels


def format_labels(labels, clean=False, min_area=9, despur=False, verbose=False, ignore=False):
"""
Expand Down Expand Up @@ -117,3 +71,70 @@ def format_labels(labels, clean=False, min_area=9, despur=False, verbose=False,
fastremap.renumber(labels,in_place=True) # convenient to have unit increments from 1 to N cells
labels = fastremap.refit(labels) # put into smaller data type if possible
return labels

def delete_spurs(mask):
pad = 1
#must fill single holes in image to avoid cusps causing issues. Will limit to holes of size ___
skel = remove_small_holes(np.pad(mask,pad,mode='constant'),5)

nbad = 1
niter = 0
while (nbad > 0):
bad_points = endpoints(skel)
skel = np.logical_and(skel,np.logical_not(bad_points))
nbad = np.sum(bad_points)
niter+=1

unpad = tuple([slice(pad,-pad)]*skel.ndim)
skel = skel[unpad] #unpad

return skel

# this still only works for 2D
def endpoints(skel):
pad = 1 # appears to require padding to work properly....
skel = np.pad(skel,pad)
endpoint1=np.array([[0, 0, 0],
[0, 1, 0],
[2, 1, 2]])

endpoint2=np.array([[0, 0, 0],
[0, 1, 2],
[0, 2, 1]])

endpoint3=np.array([[0, 0, 2],
[0, 1, 1],
[0, 0, 2]])

endpoint4=np.array([[0, 2, 1],
[0, 1, 2],
[0, 0, 0]])

endpoint5=np.array([[2, 1, 2],
[0, 1, 0],
[0, 0, 0]])

endpoint6=np.array([[1, 2, 0],
[2, 1, 0],
[0, 0, 0]])

endpoint7=np.array([[2, 0, 0],
[1, 1, 0],
[2, 0, 0]])

endpoint8=np.array([[0, 0, 0],
[2, 1, 0],
[1, 2, 0]])

ep1=mh.morph.hitmiss(skel,endpoint1)
ep2=mh.morph.hitmiss(skel,endpoint2)
ep3=mh.morph.hitmiss(skel,endpoint3)
ep4=mh.morph.hitmiss(skel,endpoint4)
ep5=mh.morph.hitmiss(skel,endpoint5)
ep6=mh.morph.hitmiss(skel,endpoint6)
ep7=mh.morph.hitmiss(skel,endpoint7)
ep8=mh.morph.hitmiss(skel,endpoint8)
ep = ep1+ep2+ep3+ep4+ep5+ep6+ep7+ep8
unpad = tuple([slice(pad,-pad)]*ep.ndim)
ep = ep[unpad]
return ep

0 comments on commit cde12f5

Please sign in to comment.