Skip to content

Commit

Permalink
refactor: renamed neurograph to fragments graph
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Nov 13, 2024
1 parent 249abee commit f2344b6
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 91 deletions.
2 changes: 1 addition & 1 deletion src/deep_neurographs/fragment_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
"""

from collections import defaultdict
from tqdm import tqdm

import networkx as nx
import numpy as np
from tqdm import tqdm

from deep_neurographs import geometry

Expand Down
File renamed without changes.
195 changes: 106 additions & 89 deletions src/deep_neurographs/generate_proposals.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,19 @@


def run(
graph,
fragments_graph,
radius,
complex_bool=False,
long_range_bool=True,
progress_bar=True,
trim_endpoints_bool=True,
):
"""
Generates proposals emanating from "leaf".
Generates proposals for fragments graph.
Parameters
----------
graph : FragmentsGraph
fragments_graph : FragmentsGraph
Graph that proposals will be generated for.
radius : float
Maximum Euclidean distance between endpoints of proposal.
Expand All @@ -58,29 +58,29 @@ def run(
"""
# Initializations
connections = dict()
kdtree = init_kdtree(graph, complex_bool)
kdtree = init_kdtree(fragments_graph, complex_bool)
radius *= RADIUS_SCALING_FACTOR if trim_endpoints_bool else 1.0
if progress_bar:
iterable = tqdm(graph.get_leafs(), desc="Proposals")
iterable = tqdm(fragments_graph.get_leafs(), desc="Proposals")
else:
iterable = graph.get_leafs()
iterable = fragments_graph.get_leafs()

# Main
for leaf in iterable:
# Generate potential proposals
candidates = get_candidates(
graph,
fragments_graph,
leaf,
kdtree,
radius,
graph.proposals_per_leaf,
fragments_graph.proposals_per_leaf,
complex_bool,
)

# Generate long range proposals (if applicable)
if len(candidates) == 0 and long_range_bool:
candidates = get_candidates(
graph,
fragments_graph,
leaf,
kdtree,
radius * RADIUS_SCALING_FACTOR,
Expand All @@ -90,33 +90,34 @@ def run(

# Determine which potential proposals to keep
for i in candidates:
leaf_swc_id = graph.nodes[leaf]["swc_id"]
pair_id = frozenset((leaf_swc_id, graph.nodes[i]["swc_id"]))
leaf_swc_id = fragments_graph.nodes[leaf]["swc_id"]
node_swc_id = fragments_graph.nodes[i]["swc_id"]
pair_id = frozenset((leaf_swc_id, node_swc_id))
if pair_id in connections.keys():
cur_proposal = connections[pair_id]
cur_dist = graph.proposal_length(cur_proposal)
if graph.dist(leaf, i) < cur_dist:
graph.remove_proposal(cur_proposal)
cur_dist = fragments_graph.proposal_length(cur_proposal)
if fragments_graph.dist(leaf, i) < cur_dist:
fragments_graph.remove_proposal(cur_proposal)
del connections[pair_id]
else:
continue

# Add proposal
graph.add_proposal(leaf, i)
fragments_graph.add_proposal(leaf, i)
connections[pair_id] = frozenset({leaf, i})

# Trim endpoints (if applicable)
n_trimmed = 0
if trim_endpoints_bool:
radius /= RADIUS_SCALING_FACTOR
long_range, in_range = separate_proposals(graph, radius)
graph, n_trimmed_1 = run_trimming(graph, long_range, radius)
graph, n_trimmed_2 = run_trimming(graph, in_range, radius)
n_trimmed = n_trimmed_1 + n_trimmed_2
long_range, in_range = partition_proposals(fragments_graph, radius)
cnt_1 = run_trimming(fragments_graph, long_range, radius)
cnt_2 = run_trimming(fragments_graph, in_range, radius)
n_trimmed = cnt_1 + cnt_2
return n_trimmed


def init_kdtree(graph, complex_bool):
def init_kdtree(fragments_graph, complex_bool):
"""
Initializes a KD-Tree used to generate proposals.
Expand All @@ -130,13 +131,14 @@ def init_kdtree(graph, complex_bool):
Returns
-------
scipy.spatial.cKDTree
kdtree.
kdtree built from all xyz coordinates across edges in graph if
complex_bool is True; otherwise, only built from leaf nodes.
"""
if complex_bool:
return graph.get_kdtree()
return fragments_graph.get_kdtree()
else:
return graph.get_kdtree(node_type="leaf")
return fragments_graph.get_kdtree(node_type="leaf")


def get_candidates(
Expand All @@ -157,16 +159,16 @@ def get_candidates(
return list() if max_proposals < 0 else candidates


def search_kdtree(graph, leaf, kdtree, radius, max_proposals):
def search_kdtree(fragments_graph, leaf, kdtree, radius, max_proposals):
"""
Generates proposals for node "leaf" in "graph" by finding candidate
xyz points on distinct connected components nearby.
Generates proposals emanating from node "leaf" by finding candidate xyz
points on distinct connected components nearby.
Parameters
----------
graph : FragmentsGraph
Graph built from swc files.
kdtree : ...
fragments_graph : FragmentsGraph
Graph that proposals will be generated for.
kdtree : scipy.spatial.cKDTree
...
leaf : int
Leaf node that proposals are to be generated from.
Expand All @@ -183,10 +185,10 @@ def search_kdtree(graph, leaf, kdtree, radius, max_proposals):
"""
# Generate candidates
candidates = dict()
leaf_xyz = graph.nodes[leaf]["xyz"]
leaf_xyz = fragments_graph.nodes[leaf]["xyz"]
for xyz in geometry.query_ball(kdtree, leaf_xyz, radius):
swc_id = graph.xyz_to_swc(xyz)
if swc_id != graph.nodes[leaf]["swc_id"]:
swc_id = fragments_graph.xyz_to_swc(xyz)
if swc_id != fragments_graph.nodes[leaf]["swc_id"]:
d = geometry.dist(leaf_xyz, xyz)
if swc_id not in candidates.keys():
candidates[swc_id] = {"dist": d, "xyz": tuple(xyz)}
Expand Down Expand Up @@ -231,13 +233,13 @@ def get_best(candidates, max_proposals):
return list_candidates_xyz(candidates)


def get_connecting_node(graph, leaf, xyz, radius, complex_bool):
def get_connecting_node(fragments_graph, leaf, xyz, radius, complex_bool):
"""
Gets the node that proposal with leaf will connect to.
Gets node that proposal emanating from "leaf" will connect to.
Parameters
----------
graph : FragmentsGraph
fragments_graph : FragmentsGraph
Graph containing "leaf".
leaf : int
Leaf node.
Expand All @@ -247,28 +249,28 @@ def get_connecting_node(graph, leaf, xyz, radius, complex_bool):
Returns
-------
int
Node id.
Node id that proposal will connect to.
"""
edge = graph.xyz_to_edge[xyz]
node = get_closer_endpoint(graph, edge, xyz)
if graph.dist(leaf, node) < radius:
edge = fragments_graph.xyz_to_edge[xyz]
node = get_closer_endpoint(fragments_graph, edge, xyz)
if fragments_graph.dist(leaf, node) < radius:
return node
elif complex_bool:
attrs = graph.get_edge_data(*edge)
attrs = fragments_graph.get_edge_data(*edge)
idx = np.where(np.all(attrs["xyz"] == xyz, axis=1))[0][0]
if type(idx) is int:
return graph.split_edge(edge, attrs, idx)
return fragments_graph.split_edge(edge, attrs, idx)
return None


def get_closer_endpoint(graph, edge, xyz):
def get_closer_endpoint(fragments_graph, edge, xyz):
"""
Gets the node from "edge" that is closer to "xyz".
Gets node from "edge" that is closer to "xyz".
Parameters
----------
graph : FragmentsGraph
fragments_graph : FragmentsGraph
Graph containing "edge".
edge : tuple
Edge to be checked.
Expand All @@ -277,51 +279,66 @@ def get_closer_endpoint(graph, edge, xyz):
Returns
-------
tuple
Node id and its distance from "xyz".
int
Node closer to "xyz".
"""
i, j = tuple(edge)
d_i = geometry.dist(graph.nodes[i]["xyz"], xyz)
d_j = geometry.dist(graph.nodes[j]["xyz"], xyz)
d_i = geometry.dist(fragments_graph.nodes[i]["xyz"], xyz)
d_j = geometry.dist(fragments_graph.nodes[j]["xyz"], xyz)
return i if d_i < d_j else j


def separate_proposals(graph, radius):
def partition_proposals(fragments_graph, radius):
"""
Partitions proposals in "fragments_graph" into long-range and in-range
categories based on a specified length threshold.
Parameters
----------
fragments_graph : FragmentsGraph
Graph with proposals to be partitioned.
radius : float
Length threshold used to partition proposals. Proposals with length
greater than "radius" are said to be long-range; otherwise, in-range.
Returns
-------
list, list
Lists of long-range and in-range proposals.
"""
long_range_proposals = list()
proposals = list()
for proposal in graph.proposals:
i, j = tuple(proposal)
if graph.dist(i, j) > radius:
long_range_proposals.append(proposal)
in_range_proposals = list()
for p in fragments_graph.proposals:
if fragments_graph.proposal_length(p) > radius:
long_range_proposals.append(p)
else:
proposals.append(proposal)
return long_range_proposals, proposals
in_range_proposals.append(p)
return long_range_proposals, in_range_proposals


# --- Trim Endpoints ---
def run_trimming(graph, proposals, radius):
n_endpoints_trimmed = 0
def run_trimming(fragments_graph, proposals, radius):
n_trimmed = 0
long_radius = radius * RADIUS_SCALING_FACTOR
for proposal in deepcopy(proposals):
i, j = tuple(proposal)
is_simple = graph.is_simple(proposal)
is_single = graph.is_single_proposal(proposal)
for p in deepcopy(proposals):
is_simple = fragments_graph.is_simple(p)
is_single = fragments_graph.is_single_proposal(p)
trim_bool = False
if is_simple and is_single:
graph, trim_bool = trim_endpoints(
graph, i, j, long_radius
)
elif graph.dist(i, j) > radius:
graph.remove_proposal(proposal)
n_endpoints_trimmed += 1 if trim_bool else 0
return graph, n_endpoints_trimmed
trim_bool = trim_endpoints(fragments_graph, p, long_radius)
elif fragments_graph.proposal_length(p) > radius:
fragments_graph.remove_proposal(p)
n_trimmed += 1 if trim_bool else 0
return n_trimmed


def trim_endpoints(graph, i, j, radius):
def trim_endpoints(fragments_graph, proposal, radius):
# Initializations
branch_i = graph.branch(i)
branch_j = graph.branch(j)
i, j = tuple(proposal)
branch_i = fragments_graph.branch(i)
branch_j = fragments_graph.branch(j)

# Check both orderings
idx_i, idx_j = trim_endpoints_ordered(branch_i, branch_j)
Expand All @@ -334,14 +351,14 @@ def trim_endpoints(graph, i, j, radius):

# Update branches (if applicable)
if min(d1, d2) > radius:
graph.remove_proposal(frozenset((i, j)))
return graph, False
fragments_graph.remove_proposal(frozenset((i, j)))
return False
elif min(d1, d2) + 2 < geometry.dist(branch_i[0], branch_j[0]):
if compute_dot(branch_i, branch_j, idx_i, idx_j) < DOT_THRESHOLD:
graph = trim_to_idx(graph, i, idx_i)
graph = trim_to_idx(graph, j, idx_j)
return graph, True
return graph, False
fragments_graph = trim_to_idx(fragments_graph, i, idx_i)
fragments_graph = trim_to_idx(fragments_graph, j, idx_j)
return True
return False


def trim_endpoints_ordered(branch_1, branch_2):
Expand Down Expand Up @@ -376,13 +393,13 @@ def trim_endpoint(branch_1, branch_2):
return 0 if best_idx is None else best_idx


def trim_to_idx(graph, i, idx):
def trim_to_idx(fragments_graph, i, idx):
"""
Trims the branch emanating from "i".
Parameters
----------
graph : FragmentsGraph
fragments_graph : FragmentsGraph
Graph containing node "i"
i : int
Leaf node.
Expand All @@ -395,21 +412,21 @@ def trim_to_idx(graph, i, idx):
"""
# Update node
branch_xyz = graph.branch(i, key="xyz")
branch_radii = graph.branch(i, key="radius")
graph.nodes[i]["xyz"] = branch_xyz[idx]
graph.nodes[i]["radius"] = branch_radii[idx]
branch_xyz = fragments_graph.branch(i, key="xyz")
branch_radii = fragments_graph.branch(i, key="radius")
fragments_graph.nodes[i]["xyz"] = branch_xyz[idx]
fragments_graph.nodes[i]["radius"] = branch_radii[idx]

# Update edge
j = graph.leaf_neighbor(i)
graph.edges[i, j]["xyz"] = branch_xyz[idx::]
graph.edges[i, j]["radius"] = branch_radii[idx::]
j = fragments_graph.leaf_neighbor(i)
fragments_graph.edges[i, j]["xyz"] = branch_xyz[idx::]
fragments_graph.edges[i, j]["radius"] = branch_radii[idx::]
for k in range(idx):
try:
del graph.xyz_to_edge[tuple(branch_xyz[k])]
del fragments_graph.xyz_to_edge[tuple(branch_xyz[k])]
except KeyError:
pass
return graph
return fragments_graph


# --- utils ---
Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/utils/graph_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def run(
FragmentsGraph generated from swc files.
"""
from deep_neurographs.neurograph import FragmentsGraph
from deep_neurographs.fragments_graph import FragmentsGraph

# Load fragments and extract irreducibles
self.img_bbox = img_util.init_bbox(img_patch_origin, img_patch_shape)
Expand Down

0 comments on commit f2344b6

Please sign in to comment.