Skip to content

Commit

Permalink
refactor: updated fragment filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Nov 13, 2024
1 parent 8637c2e commit 249abee
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 63 deletions.
160 changes: 98 additions & 62 deletions src/deep_neurographs/fragment_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,83 +8,98 @@
other from a FragmentsGraph.
"""

from collections import defaultdict
from tqdm import tqdm

import networkx as nx
import numpy as np
from networkx import connected_components
from tqdm import tqdm

from deep_neurographs import geometry
from deep_neurographs.utils import util

COLOR = "1.0 0.0 0.0"
QUERY_DIST = 15


# --- Curvy Removal ---
def remove_curvy(graph, max_length, ratio=0.5):
def remove_curvy(fragments_graph, max_length, ratio=0.5):
"""
Removes connected components with 2 nodes from "fragments_graph" that are
"curvy" fragments, based on a specified ratio of endpoint distance to edge
length and a maximum length threshold.
Parameters
----------
fragments_graph : FragmentsGraph
Graph generated from fragments of a predicted segmentation.
max_length : float
The maximum allowable length (in microns) for an edge to be considered
for removal.
ratio : float, optional
Threshold ratio of endpoint distance to edge length. Components with a
ratio below this value are considered "curvy" and are removed. The
default is 0.5.
Returns
-------
int
Number of fragments removed from the graph.
"""
deleted_ids = set()
components = [c for c in connected_components(graph) if len(c) == 2]
for nodes in tqdm(components, desc="Filter Curvy Fragment"):
if len(nodes) == 2:
i, j = tuple(nodes)
length = graph.edges[i, j]["length"]
endpoint_dist = graph.dist(i, j)
if endpoint_dist / length < ratio and length < max_length:
deleted_ids.add(graph.edges[i, j]["swc_id"])
delete_fragment(graph, i, j)
components = get_line_components(fragments_graph)
for nodes in tqdm(components, desc="Filter Curvy Fragments"):
i, j = tuple(nodes)
length = fragments_graph.edges[i, j]["length"]
endpoint_dist = fragments_graph.dist(i, j)
if endpoint_dist / length < ratio and length < max_length:
deleted_ids.add(fragments_graph.edges[i, j]["swc_id"])
delete_fragment(fragments_graph, i, j)
return len(deleted_ids)


# --- Doubles Removal ---
def remove_doubles(graph, max_length, node_spacing, output_dir=None):
def remove_doubles(fragments_graph, max_length, node_spacing):
"""
Removes connected components from "neurgraph" that are likely to be a
double.
Removes connected components from "fragments_graph" that are likely to be
a double -- caused by ghosting in the image.
Parameters
----------
graph : FragmentsGraph
fragments_graph : FragmentsGraph
Graph to be searched for doubles.
max_length : int
Maximum size of connected components to be searched.
node_spacing : int
Expected distance in microns between nodes in "graph".
output_dir : str or None, optional
Directory that doubles will be written to. The default is None.
Expected distance (in microns) between nodes in "fragments_graph".
Returns
-------
graph
Graph with doubles removed.
int
Number of fragments removed from graph.
"""
# Initializations
components = [c for c in connected_components(graph) if len(c) == 2]
components = get_line_components(fragments_graph)
deleted_ids = set()
kdtree = graph.get_kdtree()
if output_dir:
util.mkdir(output_dir, delete=True)
kdtree = fragments_graph.get_kdtree()

# Main
desc = "Filter Doubled Fragment"
desc = "Filter Doubled Fragments"
for idx in tqdm(np.argsort([len(c) for c in components]), desc=desc):
i, j = tuple(components[idx])
swc_id = graph.nodes[i]["swc_id"]
swc_id = fragments_graph.nodes[i]["swc_id"]
if swc_id not in deleted_ids:
if graph.edges[i, j]["length"] < max_length:
if fragments_graph.edges[i, j]["length"] < max_length:
# Check doubles criteria
n_points = len(graph.edges[i, j]["xyz"])
hits = compute_projections(graph, kdtree, (i, j))
n_points = len(fragments_graph.edges[i, j]["xyz"])
hits = compute_projections(fragments_graph, kdtree, (i, j))
if check_doubles_criteria(hits, n_points):
if output_dir:
graph.to_swc(output_dir, [i, j], color=COLOR)
delete_fragment(graph, i, j)
delete_fragment(fragments_graph, i, j)
deleted_ids.add(swc_id)
return len(deleted_ids)


def compute_projections(graph, kdtree, edge):
def compute_projections(fragments_graph, kdtree, edge):
"""
Given a fragment defined by "edge", this routine iterates of every xyz in
the fragment and projects it onto the closest fragment. For each detected
Expand All @@ -93,11 +108,11 @@ def compute_projections(graph, kdtree, edge):
Parameters
----------
graph : graph
fragments_graph : graph
Graph that contains "edge".
kdtree : KDTree
KD-Tree that contains all xyz coordinates of every fragment in
"graph".
"fragments_graph".
edge : tuple
Pair of leaf nodes that define a fragment.
Expand All @@ -109,13 +124,13 @@ def compute_projections(graph, kdtree, edge):
"""
hits = defaultdict(list)
query_id = graph.edges[edge]["swc_id"]
for i, xyz in enumerate(graph.edges[edge]["xyz"]):
query_id = fragments_graph.edges[edge]["swc_id"]
for i, xyz in enumerate(fragments_graph.edges[edge]["xyz"]):
# Compute projections
best_id = None
best_dist = np.inf
for hit_xyz in geometry.query_ball(kdtree, xyz, QUERY_DIST):
hit_id = graph.xyz_to_swc(hit_xyz)
hit_id = fragments_graph.xyz_to_swc(hit_xyz)
if hit_id is not None and hit_id != query_id:
if geometry.dist(hit_xyz, xyz) < best_dist:
best_dist = geometry.dist(hit_xyz, xyz)
Expand Down Expand Up @@ -157,54 +172,54 @@ def check_doubles_criteria(hits, n_points):
return False


def delete_fragment(graph, i, j):
def delete_fragment(fragments_graph, i, j):
"""
Deletes nodes "i" and "j" from "graph", where these nodes form a connected
component.
Deletes nodes "i" and "j" from "fragments_graph", where these nodes form a
connected component.
Parameters
----------
graph : FragmentsGraph
Graph that contains nodes to be deleted.
fragments_graph : FragmentsGraph
Graph that contains nodes to be removed.
i : int
Node to be removed.
j : int
Node to be removed.
Returns
-------
graph
fragments_graph
Graph with nodes removed.
"""
graph = remove_xyz_entries(graph, i, j)
graph.swc_ids.remove(graph.nodes[i]["swc_id"])
graph.remove_nodes_from([i, j])
fragments_graph = remove_xyz_entries(fragments_graph, i, j)
fragments_graph.swc_ids.remove(fragments_graph.nodes[i]["swc_id"])
fragments_graph.remove_nodes_from([i, j])


def remove_xyz_entries(graph, i, j):
def remove_xyz_entries(fragments_graph, i, j):
"""
Removes dictionary entries from "graph.xyz_to_edge" corresponding to
the edge {i, j}.
Removes dictionary entries from "fragments_graph.xyz_to_edge"
corresponding to the edge {i, j}.
Parameters
----------
graph : graph
fragments_graph : graph
Graph to be updated.
i : int
Node in "graph".
Node in graph.
j : int
Node in "graph".
Node in graph.
Returns
-------
graph
Updated graph.
"""
for xyz in graph.edges[i, j]["xyz"]:
del graph.xyz_to_edge[tuple(xyz)]
return graph
for xyz in fragments_graph.edges[i, j]["xyz"]:
del fragments_graph.xyz_to_edge[tuple(xyz)]
return fragments_graph


def upd_hits(hits, key, value):
Expand All @@ -215,8 +230,8 @@ def upd_hits(hits, key, value):
Parameters
----------
hits : dict
Stores swd_ids of fragments that are within a certain distance a query
fragment along with the corresponding distances.
Stores swd_ids of fragments within a certain distance a query fragment
along with the corresponding distances.
key : str
swc id of some fragment.
value : float
Expand All @@ -229,9 +244,30 @@ def upd_hits(hits, key, value):
Updated version of hits.
"""
if key in hits.keys():
if key in hits:
if value < hits[key]:
hits[key] = value
else:
hits[key] = value
return hits


# --- utils ---
def get_line_components(graph):
"""
Identifies and returns all line components in the given graph. A line
component is defined as a connected component with exactly two nodes.
Parameters
----------
graph : networkx.Graph
Input graph in which line components are to be identified.
Returns
-------
List[set]
List of sets, where each set contains two nodes representing a
connected component with exactly two nodes.
"""
return [c for c in nx.connected_components(graph) if len(c) == 2]
2 changes: 1 addition & 1 deletion src/deep_neurographs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
model_path,
output_dir,
config,
device=None,
device="cpu",
is_multimodal=False,
label_path=None,
log_runtimes=True,
Expand Down

0 comments on commit 249abee

Please sign in to comment.