diff --git a/snudda/place/bend_morphologies.py b/snudda/place/bend_morphologies.py index 4b4a69a02..c81144458 100644 --- a/snudda/place/bend_morphologies.py +++ b/snudda/place/bend_morphologies.py @@ -10,6 +10,9 @@ class BendMorphologies: def __init__(self, region_mesh: RegionMeshRedux, rng): + if type(region_mesh) == str: + region_mesh = RegionMeshRedux(mesh_path=region_mesh) + self.region_mesh = region_mesh self.rng = rng @@ -20,7 +23,7 @@ def check_if_inside(self, morphology: NeuronMorphologyExtended): return inside_flag - def bend_morphology(self, morphology: NeuronMorphologyExtended, k=50e-6): + def bend_morphology_OLD(self, morphology: NeuronMorphologyExtended, k=50e-6): # TODO: Parent point idx is included if parent section is of same type as section! # So we should not rotate the first point! @@ -49,7 +52,7 @@ def bend_morphology(self, morphology: NeuronMorphologyExtended, k=50e-6): coords = section.morphology_data.geometry[point_idx, :] dist = self.region_mesh.distance_to_border(points=coords) - P = 1 / (1 + np.exp(-k * dist)) + P = 1 / (1 + np.exp(-dist/k)) if self.rng.uniform(1) < P: # We need to randomize new rotation matrix @@ -62,7 +65,7 @@ def bend_morphology(self, morphology: NeuronMorphologyExtended, k=50e-6): candidate_dist = self.region_mesh.distance_to_border(points=candidate_pos) - P_candidate = np.divide(1, 1 + np.exp(-k * candidate_dist)) + P_candidate = np.divide(1, 1 + np.exp(-candidate_dist/k)) picked_idx = self.rng.choice(n_random, p=P_candidate) new_coords = candidate_pos[picked_idx, :] @@ -92,6 +95,81 @@ def bend_morphology(self, morphology: NeuronMorphologyExtended, k=50e-6): parent_rotation_matrices[point_idx] = rotation_matrix + def bend_morphology(self, morphology: NeuronMorphologyExtended, k=20e-6, n_random=10): + + # k -- how early will the neuron start bending when it approaches the border + + candidate_pos = np.zeros((n_random, 3)) + + parent_direction = dict() + + old_rotation_representation = self.get_full_rotation_representation(morphology=morphology) + new_rotation_representation = dict() + + for section in morphology.section_iterator(): + if (section.section_id, section.section_type) in parent_direction: + parent_dir, parent_point = parent_direction[section.section_id, section.section_type] + else: + if morphology.rotation is not None: + parent_dir = np.matmul(morphology.rotation, np.array([[1], [0], [0]])).T + else: + parent_dir = np.array([[1, 0, 0]]) + + if morphology.position is not None: + parent_point = morphology.position + else: + parent_point = np.zeros((3, )) + + rot_rep = old_rotation_representation[section.section_id, section.section_type] + new_rot_rep = [] + + # Loop over all points in section + for idx, (rotation, length) in enumerate(rot_rep): + + try: + segment_direction = rotation.apply(parent_dir) + putative_point = segment_direction * length + parent_point + except: + import traceback + print(traceback.format_exc()) + import pdb + pdb.set_trace() + + # Check if point is too close to edge + dist = self.region_mesh.distance_to_border(points=putative_point) + P_keep = 1 / (1 + np.exp(-dist/k)) + + if self.rng.uniform(1) < P_keep: + # We need to randomize new rotation matrix + # https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.html + angles = self.rng.uniform(size=(n_random, 3), low=-0.2, high=0.2) # Angles in radians + avoidance_rotation = Rotation.from_euler(seq="XYZ", angles=angles) + + for idx, av_rot in enumerate(avoidance_rotation): + + candidate_pos[idx, :] = parent_point + length * (av_rot*rotation).apply(vectors=parent_direction) + + candidate_dist = self.region_mesh.distance_to_border(points=candidate_pos) + + P_candidate = np.divide(1, 1 + np.exp(-candidate_dist/k)) + picked_idx = self.rng.choice(n_random, p=P_candidate) + + new_rot = avoidance_rotation[picked_idx] * rotation + new_rot_rep.append((new_rot, length)) + segment_direction = new_rot.apply(parent_direction) + else: + new_rot_rep.append((rotation, length)) + + parent_point = segment_direction * length + parent_point + parent_dir = segment_direction + + for child_id, child_type in section.child_section_id.T: + parent_direction[child_id, child_type] = (parent_dir, parent_point) + + new_rotation_representation[section.section_id, section.section_type] = new_rot_rep + + return new_rotation_representation + def get_full_rotation_representation(self, morphology: MorphologyData): rotation_representation = dict() @@ -130,12 +208,12 @@ def apply_rotation(self, morphology: MorphologyData, rotation_representation): if (section.section_id, section.section_type) in parent_direction: parent_dir, parent_pos = parent_direction[section.section_id, section.section_type] else: - if morphology.rotation: - parent_dir = np.matmul(morphology.rotation, np.array([[1, 0, 0]])) + if morphology.rotation is not None: + parent_dir = np.matmul(morphology.rotation, np.array([[1, 0, 0]]).T).T else: parent_dir = np.array([[1, 0, 0]]) - if morphology.position: + if morphology.position is not None: parent_pos = morphology.position else: parent_pos = np.zeros((3, )) @@ -257,9 +335,30 @@ def test_rotation_representation(): print(f"Geometry matches") +def test_bending(): + + file_path = "../data/neurons/striatum/dspn/str-dspn-e150602_c1_D1-mWT-0728MSN01-v20190508/WT-0728MSN01-cor-rep-ax.swc" + mesh_path = "../data/mesh/Striatum-d-right.obj" + + md = MorphologyData(swc_file=file_path) + bm = BendMorphologies(mesh_path, rng=np.random.default_rng()) + + before = md.clone(position=np.array([7300, 4000, -0.8])*1e-6, rotation=np.eye(3)) + after = md.clone(position=np.array([7300, 4000, -0.8])*1e-6, rotation=np.eye(3)) + + new_rot_rep = bm.bend_morphology(after) + new_coord = bm.apply_rotation(after, new_rot_rep) + after.geometry[:, :3] = new_coord + + before.plot() + after.plot() + + if __name__ == "__main__": - test_rotation_representation() + # test_rotation_representation() + + test_bending() import pdb pdb.set_trace() \ No newline at end of file