From 249abee44f7949391a1b64a557394c6173ef609f Mon Sep 17 00:00:00 2001 From: anna-grim Date: Wed, 13 Nov 2024 03:28:29 +0000 Subject: [PATCH] refactor: updated fragment filtering --- src/deep_neurographs/fragment_filtering.py | 160 +++++++++++++-------- src/deep_neurographs/inference.py | 2 +- 2 files changed, 99 insertions(+), 63 deletions(-) diff --git a/src/deep_neurographs/fragment_filtering.py b/src/deep_neurographs/fragment_filtering.py index d77c492..8a735ef 100644 --- a/src/deep_neurographs/fragment_filtering.py +++ b/src/deep_neurographs/fragment_filtering.py @@ -8,83 +8,98 @@ other from a FragmentsGraph. """ + from collections import defaultdict +from tqdm import tqdm +import networkx as nx import numpy as np -from networkx import connected_components -from tqdm import tqdm from deep_neurographs import geometry -from deep_neurographs.utils import util -COLOR = "1.0 0.0 0.0" QUERY_DIST = 15 # --- Curvy Removal --- -def remove_curvy(graph, max_length, ratio=0.5): +def remove_curvy(fragments_graph, max_length, ratio=0.5): + """ + Removes connected components with 2 nodes from "fragments_graph" that are + "curvy" fragments, based on a specified ratio of endpoint distance to edge + length and a maximum length threshold. + + Parameters + ---------- + fragments_graph : FragmentsGraph + Graph generated from fragments of a predicted segmentation. + max_length : float + The maximum allowable length (in microns) for an edge to be considered + for removal. + ratio : float, optional + Threshold ratio of endpoint distance to edge length. Components with a + ratio below this value are considered "curvy" and are removed. The + default is 0.5. + + Returns + ------- + int + Number of fragments removed from the graph. 
+ + """ deleted_ids = set() - components = [c for c in connected_components(graph) if len(c) == 2] - for nodes in tqdm(components, desc="Filter Curvy Fragment"): - if len(nodes) == 2: - i, j = tuple(nodes) - length = graph.edges[i, j]["length"] - endpoint_dist = graph.dist(i, j) - if endpoint_dist / length < ratio and length < max_length: - deleted_ids.add(graph.edges[i, j]["swc_id"]) - delete_fragment(graph, i, j) + components = get_line_components(fragments_graph) + for nodes in tqdm(components, desc="Filter Curvy Fragments"): + i, j = tuple(nodes) + length = fragments_graph.edges[i, j]["length"] + endpoint_dist = fragments_graph.dist(i, j) + if endpoint_dist / length < ratio and length < max_length: + deleted_ids.add(fragments_graph.edges[i, j]["swc_id"]) + delete_fragment(fragments_graph, i, j) return len(deleted_ids) # --- Doubles Removal --- -def remove_doubles(graph, max_length, node_spacing, output_dir=None): +def remove_doubles(fragments_graph, max_length, node_spacing): """ - Removes connected components from "neurgraph" that are likely to be a - double. + Removes connected components from "fragments_graph" that are likely to be + a double -- caused by ghosting in the image. Parameters ---------- - graph : FragmentsGraph + fragments_graph : FragmentsGraph Graph to be searched for doubles. max_length : int Maximum size of connected components to be searched. node_spacing : int - Expected distance in microns between nodes in "graph". - output_dir : str or None, optional - Directory that doubles will be written to. The default is None. + Expected distance (in microns) between nodes in "fragments_graph". Returns ------- - graph - Graph with doubles removed. + int + Number of fragments removed from graph. 
""" # Initializations - components = [c for c in connected_components(graph) if len(c) == 2] + components = get_line_components(fragments_graph) deleted_ids = set() - kdtree = graph.get_kdtree() - if output_dir: - util.mkdir(output_dir, delete=True) + kdtree = fragments_graph.get_kdtree() # Main - desc = "Filter Doubled Fragment" + desc = "Filter Doubled Fragments" for idx in tqdm(np.argsort([len(c) for c in components]), desc=desc): i, j = tuple(components[idx]) - swc_id = graph.nodes[i]["swc_id"] + swc_id = fragments_graph.nodes[i]["swc_id"] if swc_id not in deleted_ids: - if graph.edges[i, j]["length"] < max_length: + if fragments_graph.edges[i, j]["length"] < max_length: # Check doubles criteria - n_points = len(graph.edges[i, j]["xyz"]) - hits = compute_projections(graph, kdtree, (i, j)) + n_points = len(fragments_graph.edges[i, j]["xyz"]) + hits = compute_projections(fragments_graph, kdtree, (i, j)) if check_doubles_criteria(hits, n_points): - if output_dir: - graph.to_swc(output_dir, [i, j], color=COLOR) - delete_fragment(graph, i, j) + delete_fragment(fragments_graph, i, j) deleted_ids.add(swc_id) return len(deleted_ids) -def compute_projections(graph, kdtree, edge): +def compute_projections(fragments_graph, kdtree, edge): """ Given a fragment defined by "edge", this routine iterates of every xyz in the fragment and projects it onto the closest fragment. For each detected @@ -93,11 +108,11 @@ def compute_projections(graph, kdtree, edge): Parameters ---------- - graph : graph + fragments_graph : graph Graph that contains "edge". kdtree : KDTree KD-Tree that contains all xyz coordinates of every fragment in - "graph". + "fragments_graph". edge : tuple Pair of leaf nodes that define a fragment. 
@@ -109,13 +124,13 @@ def compute_projections(graph, kdtree, edge): """ hits = defaultdict(list) - query_id = graph.edges[edge]["swc_id"] - for i, xyz in enumerate(graph.edges[edge]["xyz"]): + query_id = fragments_graph.edges[edge]["swc_id"] + for i, xyz in enumerate(fragments_graph.edges[edge]["xyz"]): # Compute projections best_id = None best_dist = np.inf for hit_xyz in geometry.query_ball(kdtree, xyz, QUERY_DIST): - hit_id = graph.xyz_to_swc(hit_xyz) + hit_id = fragments_graph.xyz_to_swc(hit_xyz) if hit_id is not None and hit_id != query_id: if geometry.dist(hit_xyz, xyz) < best_dist: best_dist = geometry.dist(hit_xyz, xyz) @@ -157,15 +172,15 @@ def check_doubles_criteria(hits, n_points): return False -def delete_fragment(graph, i, j): +def delete_fragment(fragments_graph, i, j): """ - Deletes nodes "i" and "j" from "graph", where these nodes form a connected - component. + Deletes nodes "i" and "j" from "fragments_graph", where these nodes form a + connected component. Parameters ---------- - graph : FragmentsGraph - Graph that contains nodes to be deleted. + fragments_graph : FragmentsGraph + Graph that contains nodes to be removed. i : int Node to be removed. j : int @@ -173,28 +188,28 @@ def delete_fragment(graph, i, j): Returns ------- - graph + fragments_graph Graph with nodes removed. """ - graph = remove_xyz_entries(graph, i, j) - graph.swc_ids.remove(graph.nodes[i]["swc_id"]) - graph.remove_nodes_from([i, j]) + fragments_graph = remove_xyz_entries(fragments_graph, i, j) + fragments_graph.swc_ids.remove(fragments_graph.nodes[i]["swc_id"]) + fragments_graph.remove_nodes_from([i, j]) -def remove_xyz_entries(graph, i, j): +def remove_xyz_entries(fragments_graph, i, j): """ - Removes dictionary entries from "graph.xyz_to_edge" corresponding to - the edge {i, j}. + Removes dictionary entries from "fragments_graph.xyz_to_edge" + corresponding to the edge {i, j}. Parameters ---------- - graph : graph + fragments_graph : graph Graph to be updated. 
i : int - Node in "graph". + Node in graph. j : int - Node in "graph". + Node in graph. Returns ------- @@ -202,9 +217,9 @@ def remove_xyz_entries(graph, i, j): Updated graph. """ - for xyz in graph.edges[i, j]["xyz"]: - del graph.xyz_to_edge[tuple(xyz)] - return graph + for xyz in fragments_graph.edges[i, j]["xyz"]: + del fragments_graph.xyz_to_edge[tuple(xyz)] + return fragments_graph def upd_hits(hits, key, value): @@ -215,8 +230,8 @@ def upd_hits(hits, key, value): Parameters ---------- hits : dict - Stores swd_ids of fragments that are within a certain distance a query - fragment along with the corresponding distances. + Stores swc_ids of fragments within a certain distance of a query + fragment along with the corresponding distances. key : str swc id of some fragment. value : float @@ -229,9 +244,30 @@ def upd_hits(hits, key, value): Updated version of hits. """ - if key in hits.keys(): + if key in hits: if value < hits[key]: hits[key] = value else: hits[key] = value return hits + + +# --- utils --- +def get_line_components(graph): + """ + Identifies and returns all line components in the given graph. A line + component is defined as a connected component with exactly two nodes. + + Parameters + ---------- + graph : networkx.Graph + Input graph in which line components are to be identified. + + Returns + ------- + List[set] + List of sets, where each set contains two nodes representing a + connected component with exactly two nodes. + + """ + return [c for c in nx.connected_components(graph) if len(c) == 2] diff --git a/src/deep_neurographs/inference.py b/src/deep_neurographs/inference.py index b34e49e..2d0537d 100644 --- a/src/deep_neurographs/inference.py +++ b/src/deep_neurographs/inference.py @@ -66,7 +66,7 @@ def __init__( model_path, output_dir, config, - device=None, + device="cpu", is_multimodal=False, label_path=None, log_runtimes=True,