Skip to content

Commit

Permalink
fix high cell count
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinjohncutler committed Nov 7, 2024
1 parent 262de88 commit 571ee03
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 18 deletions.
8 changes: 4 additions & 4 deletions example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "30a3c190-5045-4840-b4c2-6d263b7a3178",
"id": "0",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -22,7 +22,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "8f85ebf1-19a9-4bd7-afca-e077d99f9018",
"id": "1",
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -107,7 +107,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "a3fd6807-a5d6-408a-a8a1-fae9dcfa1d5b",
"id": "2",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -123,7 +123,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "98835a5d-081e-4bee-b2bc-595e01638033",
"id": "3",
"metadata": {},
"outputs": [],
"source": [
Expand Down
240 changes: 226 additions & 14 deletions ncolor/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from numba import njit
import scipy
from .format_labels import format_labels, is_sequential
from skimage.segmentation import expand_labels as skimage_expand_labels
# from skimage.segmentation import expand_labels as skimage_expand_labels
import edt
from scipy.ndimage import distance_transform_edt


def label(lab, n=4, conn=2, max_depth=5, offset=0, expand=None, return_n=False):
def label(lab, n=4, conn=2, max_depth=5, offset=0, expand=None, return_n=False, greedy=False):
# needs to be in standard label form
# but also needs to be in int32 data type to work properly; the formatting automatically
# puts it into the smallest datatype to save space
Expand All @@ -26,7 +26,7 @@ def label(lab, n=4, conn=2, max_depth=5, offset=0, expand=None, return_n=False):
lab = expand_labels(lab)
# lab = np.pad(format_labels(lab),pad)
lab = format_labels(np.pad(lab,pad),background=0)
lut = get_lut(lab,n,conn,max_depth,offset,return_n)
lut = get_lut(lab,n,conn,max_depth,offset,greedy)

nc = lut[lab][unpad]*mask

Expand All @@ -35,15 +35,22 @@ def label(lab, n=4, conn=2, max_depth=5, offset=0, expand=None, return_n=False):
else:
return nc

def get_lut(lab, n=4, conn=2, max_depth=5, offset=0, greedy=False):
    """Build a lookup table that maps every label in ``lab`` to a color index.

    Parameters
    ----------
    lab : integer label array; it is passed through ``format_labels`` and
        cast to int64 before the adjacency search.
    n : number of colors ``render_net`` starts with.
    conn : connectivity forwarded to ``connect``.
    max_depth : recursion limit forwarded to ``render_net``.
    offset : seed offset forwarded to ``render_net``.
    greedy : when true, color with ``greedy_coloring`` instead of
        ``render_net``.

    Returns
    -------
    1-D ``np.uint8`` array of length ``lab.max() + 1``. Entry 0 is 0, labels
    that take part in the adjacency graph get their assigned color, and every
    other label keeps the initial value 1.

    Raises
    ------
    RuntimeError
        If ``render_net`` gives up and returns ``None``.
    """
    # lab = format_labels(lab).astype(np.int32)
    lab = format_labels(lab).astype(np.int64)

    idx = connect(lab, conn)
    if len(idx) == 0:
        # No two labels touch: there is nothing to color, so skip the solver
        # (an empty graph would otherwise come back as None from render_net).
        colors = {}
    else:
        conmap = mapidx(idx)
        if greedy:
            colors = greedy_coloring(conmap)
        else:
            colors = render_net(conmap, n=n, rand=10, max_depth=max_depth, offset=offset)
            if colors is None:
                # Fail with a clear message instead of a TypeError when
                # iterating None below.
                raise RuntimeError(
                    f"render_net found no coloring within max_depth={max_depth}"
                )

    # Start at 1 so labels without any neighbor still get a visible color.
    lut = np.ones(lab.max() + 1, dtype=np.uint8)
    for i, color in colors.items():
        lut[i] = color
    lut[0] = 0  # index 0 always maps to color 0
    return lut


def neighbors(shape, conn=1):
dim = len(shape)
Expand Down Expand Up @@ -72,6 +79,7 @@ def search(img, nbs):
def connect(img, conn=1):
buf = np.pad(img, 1, 'constant')
nbs = neighbors(buf.shape, conn)
# rst = search(buf, nbs)
rst = search(buf, nbs)
if len(rst)<2:
return rst
Expand All @@ -86,17 +94,43 @@ def connect(img, conn=1):
return rst[order][idx]

# maybe replace this with fastremap
import fastremap
def mapidx(idx):
    """Turn an array of label pairs into an adjacency dictionary.

    Every distinct value in ``idx`` becomes a key with a list value; each row
    ``(i, j)`` adds ``j`` to the list of ``i`` and ``i`` to the list of ``j``.

    NOTE: another ``mapidx`` is defined further down in this file; at import
    time the later definition wins.
    """
    # fastremap.unique was measured as marginally faster than np.unique here.
    adjacency = {label: [] for label in fastremap.unique(idx)}
    for first, second in idx:
        adjacency[first].append(second)
        adjacency[second].append(first)
    return adjacency

def mapidx(idx):
    """Build an adjacency dictionary from an array of label pairs (vectorized).

    Parameters
    ----------
    idx : array of shape (m, 2); each row ``(i, j)`` is one undirected edge.

    Returns
    -------
    dict mapping each label that occurs in ``idx`` to a 1-D array of its
    neighbors. Keys are in ascending order; within one key the neighbors
    follow the row order of ``idx``. An empty ``idx`` gives ``{}``.
    """
    idx = np.asarray(idx)
    if idx.size == 0:
        return {}

    # Interleave every edge with its mirror image -> (i, j), (j, i), ... so
    # both directions are present and the original row order is kept.
    both = np.stack((idx, idx[:, ::-1]), axis=1).reshape(-1, 2)

    # A stable sort keeps the per-node neighbor order deterministic; the
    # default sort kind gives no ordering guarantee for equal keys.
    order = np.argsort(both[:, 0], kind="stable")
    keys = both[order, 0]
    nbrs = both[order, 1]

    # keys is sorted, so group boundaries are simply where the value changes.
    breaks = np.flatnonzero(keys[1:] != keys[:-1]) + 1
    group_starts = np.concatenate(([0], breaks))

    return dict(zip(keys[group_starts], np.split(nbrs, breaks)))

# create a connection mapping
def render_net(conmap, n=4, rand=12, depth=0, max_depth=5, offset=0):
thresh = 1e4
# LARGE_INT = len(conmap)+1 # minimal to work, doesn't look as good?
LARGE_INT = len(conmap)*2 # get back to previous behavior
thresh = LARGE_INT
if depth<max_depth:
nodes = list(conmap.keys())
np.random.seed(depth+1+offset)
Expand All @@ -108,31 +142,100 @@ def render_net(conmap, n=4, rand=12, depth=0, max_depth=5, offset=0):
count+=1
k = nodes.pop(0)
counter[k] += 1
hist = [1e4] + [0] * n
hist = [LARGE_INT] + [0] * n
for p in conmap[k]:
hist[colors[p]] += 1
if min(hist)==0:
colors[k] = hist.index(min(hist))
counter[k] = 0
continue
hist[colors[k]] = 1e4
hist[colors[k]] = LARGE_INT
minc = hist.index(min(hist))
if counter[k]==rand:
counter[k] = 0
np.random.seed(count)
minc = np.random.randint(1,4)
minc = np.random.randint(1,n+1)

colors[k] = minc
for p in conmap[k]:
if colors[p] == minc:
nodes.append(p)
if count==thresh:
print(n,'-color algorthm failed,trying again with',n+1,'colors. Depth',depth)
# print(n,'-color algorthm failed,trying again with',n+1,'colors. Depth',depth)
colors = render_net(conmap,n+1,rand,depth+1,max_depth, offset)
return colors

import numpy as np
from collections import deque
# slightly faster
def render_net(conmap, n=4, rand=12, depth=0, max_depth=5, offset=0):
    """Color the nodes of an adjacency map so that neighbors differ.

    Parameters
    ----------
    conmap : dict mapping each node to an iterable of its neighbors.
    n : number of colors to try; colors are the integers 1..n.
    rand : after a node has been recolored this many times in a row without
        finding a free color, it is given a random color instead.
    depth : current retry depth (internal; callers leave it at 0).
    max_depth : number of attempts allowed; each retry uses one more color.
    offset : added to the random seed, so different offsets give different
        node orderings.

    Returns
    -------
    dict mapping every node to a color >= 1, or ``None`` (after printing a
    message) once ``depth`` reaches ``max_depth``. An empty ``conmap`` gives
    an empty dict.
    """
    if depth >= max_depth:
        print(f"N-color algorithm exceeded max depth of {max_depth}")
        return None

    # Larger than any possible neighbor count, so a slot holding it can
    # never be picked as the least-used color.
    LARGE_INT = len(conmap) * 2
    thresh = LARGE_INT  # iteration budget for this attempt

    np.random.seed(depth + 1 + offset)
    nodes = deque(np.random.permutation(list(conmap.keys())))
    colors = dict.fromkeys(nodes, 0)   # 0 means "not colored yet"
    counter = dict.fromkeys(nodes, 0)  # consecutive forced recolors per node
    count = 0

    while nodes and count < thresh:
        count += 1
        k = nodes.popleft()
        counter[k] += 1

        # hist[c] = number of neighbors currently using color c. Slot 0 is
        # rebuilt with the sentinel every iteration so that the "uncolored"
        # pseudo-color can never be selected for a node.
        hist = [LARGE_INT] + [0] * n
        for p in conmap[k]:
            hist[colors[p]] += 1

        # this seems to be the key block distinguishing it from greedy-like coloring
        if min(hist) == 0:
            # A color unused by all neighbors exists: take the first one.
            colors[k] = hist.index(0)
            counter[k] = 0
            continue

        # No free color: move to the least-used color other than the current one.
        hist[colors[k]] = LARGE_INT
        minc = hist.index(min(hist))

        if counter[k] == rand:
            # Stuck on this node: pick a random color to shake things up.
            counter[k] = 0
            minc = np.random.randint(1, n + 1)

        colors[k] = minc

        # Neighbors that now clash with k must be revisited.
        for p in conmap[k]:
            if colors[p] == minc:
                nodes.append(p)

    if nodes:
        # Budget exhausted with conflicts still queued: retry with one more
        # color. (An emptied queue is a success, even for an empty conmap.)
        return render_net(conmap, n + 1, rand, depth + 1, max_depth, offset)
    return colors


def greedy_coloring(conmap):
    """Greedy graph coloring: each node takes the smallest color not yet used
    by an already-colored neighbor.

    ``conmap`` maps each node to an iterable of neighbors; nodes are visited
    in the dictionary's iteration order. Returns a dict node -> color, with
    colors starting at 1.

    Author's note: faster and uses fewer colors than render_net.
    """
    assigned = {}
    palette = range(1, len(conmap) + 1)
    for vertex, adjacent in conmap.items():
        # Colors already claimed by neighbors that were visited earlier.
        taken = set()
        for other in adjacent:
            if other in assigned:
                taken.add(assigned[other])
        chosen = next((c for c in palette if c not in taken), None)
        if chosen is not None:
            assigned[vertex] = chosen
    return assigned


def expand_labels(label_image):
"""
Sped-up version of the scikit-image function just by dropping the distance thresholding.
Expand All @@ -141,4 +244,113 @@ def expand_labels(label_image):
nearest_label_coords = distance_transform_edt(label_image==0,
return_distances=False,
return_indices=True)
return label_image[tuple(nearest_label_coords)]
return label_image[tuple(nearest_label_coords)]



# attempts to not use njit

# def search2(img, nbs):
# line = img.ravel()
# len_line = len(line)
# nz_indices = np.flatnonzero(line) # Indices where line is non-zero

# N = len(nz_indices) # Number of non-zero elements
# D = len(nbs) # Number of neighbor offsets

# # Create repeated indices for 'i' and 'j'
# i_repeat = np.repeat(nz_indices, D) # Shape: (N * D,)
# d_tile = np.tile(nbs, N) # Shape: (N * D,)
# j = i_repeat + d_tile # Neighbor indices, Shape: (N * D,)

# # Filter out invalid neighbor indices
# valid_mask = (j >= 0) & (j < len_line)
# i_valid = i_repeat[valid_mask]
# j_valid = j[valid_mask]

# # Get the labels at the valid indices
# line_i = line[i_valid]
# line_j = line[j_valid]

# # Apply the conditions:
# # - Neighbor is non-zero
# # - Labels are different
# mask = (line_j != 0) & (line_i != line_j)

# # Collect the valid pairs
# pairs = np.column_stack((line_i[mask], line_j[mask]))

# return pairs


# def search2(img, conn=1):
# coords = np.array(np.nonzero(img)) # Convert to a NumPy array
# npix = coords.shape[1] # Number of non-zero pixels
# dim = img.ndim
# shape = img.shape

# # Define neighbor offsets
# from scipy.ndimage import generate_binary_structure
# structure = generate_binary_structure(dim, conn)
# structure[tuple([1]*dim)] = 0 # Remove the center
# neighbor_offsets = np.array(np.nonzero(structure)) - 1 # Offsets relative to center
# n_neighbors = neighbor_offsets.shape[1]

# # Compute neighbor coordinates
# # Expand coords to shape (dim, npix, 1)
# coords_expanded = coords[:, :, np.newaxis] # Shape: (dim, npix, 1)
# # Broadcast neighbor_offsets to (dim, 1, n_neighbors) and add
# neighbor_coords = coords_expanded + neighbor_offsets[:, np.newaxis, :] # Shape: (dim, npix, n_neighbors)

# # Reshape to 2D arrays for easier indexing
# neighbor_coords = neighbor_coords.reshape(dim, -1) # Shape: (dim, npix * n_neighbors)
# center_coords = np.repeat(coords_expanded, n_neighbors, axis=2).reshape(dim, -1) # Shape: (dim, npix * n_neighbors)

# # Handle out-of-bounds coordinates
# valid_mask = np.all((neighbor_coords >= 0) & (neighbor_coords < np.array(shape)[:, np.newaxis]), axis=0)

# # Filter valid neighbor coordinates
# valid_neighbor_coords = neighbor_coords[:, valid_mask]
# valid_center_coords = center_coords[:, valid_mask]

# # Map coordinates to flat indices
# neighbor_indices = np.ravel_multi_index(valid_neighbor_coords, shape)
# center_indices = np.ravel_multi_index(valid_center_coords, shape)

# # Get labels at indices
# line = img.ravel()
# labels_center = line[center_indices]
# labels_neighbor = line[neighbor_indices]

# # Filter valid pairs
# valid_pairs_mask = (labels_neighbor != 0) & (labels_neighbor != labels_center)

# # Collect valid label pairs
# pairs = np.column_stack((labels_center[valid_pairs_mask], labels_neighbor[valid_pairs_mask]))

# return pairs


# import fastremap

# def connect(img, conn=1):
# buf = np.pad(img, 1, 'constant')
# rst = search2(buf, conn)
# if len(rst) < 2:
# return rst
# # Remove duplicates and sort the pairs
# rst = fastremap.unique(np.sort(rst, axis=1), axis=0)
# return rst

# using fastremap is a lot slower?
# def connect(img, conn=1):
# buf = np.pad(img, 1, 'constant')
# nbs = neighbors(buf.shape, conn)
# rst = search(buf, nbs)
# if len(rst) < 2:
# return rst
# rst.sort(axis=1)
# print(rst.shape)
# # Use np.unique to find unique rows (label pairs)
# rst_unique = fastremap.unique(rst, axis=0)
# return rst_unique

0 comments on commit 571ee03

Please sign in to comment.