Skip to content

Commit

Permalink
update streamline registration
Browse files Browse the repository at this point in the history
  • Loading branch information
skoudoro committed Aug 21, 2023
1 parent 81afff5 commit e49e499
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 22 deletions.
4 changes: 2 additions & 2 deletions quantconn/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def merge(destination: Annotated[Path, typer.Option("--destination", "-dest",
icc_conn = pg.intraclass_corr(data=df_tmp, targets='# subject',
raters='group', ratings='score')
icc_conn.set_index('Type')
results_conn.append(float(icc_conn.loc[icc_conn['Type'] == 'ICC1',
results_conn.append(float(icc_conn.loc[icc_conn['Type'] == 'ICC3',
'ICC']))

print(f"Connectivity all scores : {results_conn}")
Expand All @@ -258,7 +258,7 @@ def merge(destination: Annotated[Path, typer.Option("--destination", "-dest",

icc_mm.set_index('Type')

results_mm.append(float(icc_mm.loc[icc_mm['Type'] == 'ICC1', 'ICC']))
results_mm.append(float(icc_mm.loc[icc_mm['Type'] == 'ICC3', 'ICC']))

with open(pjoin(destination, '_bundle_metrics_icc_report.csv'), 'w') as fh:
writer = csv.writer(fh, delimiter=',')
Expand Down
135 changes: 115 additions & 20 deletions quantconn/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,38 @@
from dipy.tracking.metrics import length
from dipy.tracking.local_tracking import LocalTracking
from dipy.tracking.streamline import Streamlines, transform_streamlines
from dipy.tracking.streamlinespeed import length
from dipy.segment.mask import median_otsu
from dipy.segment.bundles import RecoBundles

from dipy.reconst.shm import normalize_data, sph_harm_lookup, smooth_pinv
from dipy.core.sphere import HemiSphere
from dipy.core.gradients import gradient_table_from_bvals_bvecs
from dipy.reconst.shm import anisotropic_power

from quantconn.download import get_30_bundles_atlas_hcp842


def signal_powermap(data, gtab, sh_order=8, smooth=0.0):
gtab2 = gradient_table_from_bvals_bvecs(gtab.bvals[np.where(1-gtab.b0s_mask)[0]], gtab.bvecs[np.where(1-gtab.b0s_mask)[0]])
normed_data = normalize_data(data, gtab.b0s_mask)
normed_data = normed_data[..., np.where(1-gtab.b0s_mask)[0]]

signal_native_pts = HemiSphere(xyz=gtab2.bvecs)
sph_harm_basis = sph_harm_lookup.get(None)

Ba, m, n = sph_harm_basis(sh_order, signal_native_pts.theta,
signal_native_pts.phi)
L = -n * (n + 1)
invB = smooth_pinv(Ba, np.sqrt(smooth) * L)

# fit SH basis to DWI signal
normed_data_sh = np.dot(normed_data, invB.T)
ap_map_signal = anisotropic_power(normed_data_sh)

return ap_map_signal


def process_data(nifti_fname, bval_fname, bvec_fname, t1_fname, output_path,
t1_labels_fname=None, group='B'):
dwi_data, dwi_affine, dwi_img = load_nifti(nifti_fname, return_img=True)
Expand All @@ -45,13 +70,15 @@ def process_data(nifti_fname, bval_fname, bvec_fname, t1_fname, output_path,

print(':left_arrow_curving_right: Sampling/reslicing data')
vox_sz = dwi_img.header.get_zooms()[:3]
# Voxel size depends on A and B case
new_vox_size = [2.2, 2.2, 2.2]
if group.lower() == 'b':
new_vox_size = [1.9, 1.9, 1.9]

# Dynamic resampling
vox_factor = 0.14
voxsize_sorted = sorted(vox_sz)
max_vox_size, smax_vox_size = voxsize_sorted[-1], voxsize_sorted[-2]
new_vox_size = [smax_vox_size + (max_vox_size - smax_vox_size) * vox_factor ] * 3
# TODO: Check reslice order. Try with 2 and compare data (trilinear vs cubic)
resliced_data, resliced_affine = reslice(dwi_data, dwi_affine, vox_sz,
new_vox_size, order=1)
new_vox_size, order=2)

save_nifti(pjoin(output_path, 'resliced_data.nii.gz'),
resliced_data, resliced_affine)
Expand All @@ -60,6 +87,11 @@ def process_data(nifti_fname, bval_fname, bvec_fname, t1_fname, output_path,
maskdata, mask = median_otsu(
resliced_data, vol_idx=np.where(gtab.b0s_mask)[0][:2])

# Power map
powermap_data = signal_powermap(maskdata, gtab)
save_nifti(pjoin(output_path, 'powermap_data.nii.gz'),
powermap_data, resliced_affine)

print(':left_arrow_curving_right: Computing DTI metrics')
tenmodel = TensorModel(gtab)
tenfit = tenmodel.fit(maskdata)
Expand Down Expand Up @@ -238,29 +270,49 @@ def process_data(nifti_fname, bval_fname, bvec_fname, t1_fname, output_path,
t1_noskull_data, t1_noskull_affine, t1_noskull_img = \
load_nifti(t1_skullstrip_fname, return_img=True)

t1_vox_sz = t1_noskull_img.header.get_zooms()[:3]
t1_noskull_resliced_data, \
t1_noskull_resliced_affine = reslice(t1_noskull_data,
t1_noskull_affine,
t1_vox_sz, new_vox_size, order=2)

save_nifti(pjoin(output_path, 't1_noskull_resliced.nii.gz'),
t1_noskull_resliced_data, t1_noskull_resliced_affine)

# test_image_registration(maskdata, gtab, t1_noskull_data,
# t1_noskull_resliced_data, powermap_data,
# resliced_affine, t1_noskull_affine,
# t1_noskull_resliced_affine,
# output_path)

print(':left_arrow_curving_right: Connectivity matrix: Registering DWI B0s to T1 / labels')
pipeline = ["center_of_mass", "translation", "rigid", "rigid_isoscaling", "rigid_scaling"]
# Take one B0 instead of all of them or correct motion.
mean_b0 = np.mean(maskdata[..., gtab.b0s_mask], -1)
warped_b0, warped_b0_affine = affine_registration(
mean_b0, t1_noskull_data, moving_affine=resliced_affine,
static_affine=t1_noskull_affine)
mean_b0, t1_noskull_resliced_data, moving_affine=resliced_affine,
static_affine=t1_noskull_resliced_affine, pipeline=pipeline)

save_nifti(pjoin(output_path, "warped_b0.nii.gz"), warped_b0,
t1_noskull_affine)
save_nifti(pjoin(output_path, "warped_b0_resliced.nii.gz"),
warped_b0, t1_noskull_resliced_affine)

print(':left_arrow_curving_right: Connectivity matrix: Transforming Streamlines')
target_streamlines_in_t1 = transform_streamlines(target_streamlines,
warped_b0_affine,
in_place=True)
target_streamlines_in_t1 = transform_streamlines(
target_streamlines, np.linalg.inv(warped_b0_affine)) # in_place=True)

t1_noskull_resliced_data, t1_noskull_resliced_affine, t1_noskull_resliced_img = \
load_nifti(pjoin(output_path, 't1_noskull_resliced.nii.gz'), return_img=True)

# if 1:
# print(nb.aff2axcodes(mapping))
header = create_tractogram_header(
TrkFile, warped_b0_affine, maskdata.shape[:3], new_vox_size,
''.join(nib.aff2axcodes(warped_b0_affine)))
target_streamlines_in_resliced_t1_sft = StatefulTractogram(
target_streamlines_in_t1, t1_noskull_resliced_img, Space.RASMM)

save_trk(target_streamlines_in_resliced_t1_sft,
pjoin(output_path, "full_tractogram_in_resliced_t1.trk"),
bbox_valid_check=False)

target_streamlines_in_t1_sft = StatefulTractogram(target_streamlines_in_t1,
header, Space.RASMM)
t1_img, Space.RASMM)

save_trk(target_streamlines_in_t1_sft,
pjoin(output_path, "full_tractogram_in_t1.trk"),
bbox_valid_check=False)
Expand All @@ -271,7 +323,6 @@ def process_data(nifti_fname, bval_fname, bvec_fname, t1_fname, output_path,
# horizon(tractograms=[target_streamlines_in_t1_sft],
# images=[(label_data, label_affine)], interactive=True,
# cluster=True, world_coords=True)
# import ipdb; ipdb.set_trace()

print(':left_arrow_curving_right: Connectivity matrix')
# # Connectivity matrix
Expand All @@ -287,7 +338,51 @@ def process_data(nifti_fname, bval_fname, bvec_fname, t1_fname, output_path,
# plt.savefig(pjoin(output_path, "connectivity.png"))



def test_image_registration(maskdata, gtab, t1_noskull_data,
t1_noskull_resliced_data, powermap_data,
resliced_affine, t1_noskull_affine,
t1_noskull_resliced_affine,
output_path):

pipeline_1 = ["center_of_mass", "translation", "rigid"]
pipeline_2 = ["rigid_isoscaling"]
pipeline_3 = ["rigid_isoscaling", "rigid"]
pipeline_4 = ["center_of_mass", "rigid_isoscaling"]
pipeline_5 = pipeline_1 + ["rigid_isoscaling"]
pipeline_6 = pipeline_5 + ["rigid_scaling"]
all_pipelines = [pipeline_1, pipeline_2, pipeline_3, pipeline_4,
pipeline_5, pipeline_6]
print(':left_arrow_curving_right: Connectivity matrix: Registering DWI B0s to T1 / labels')
for p_idx, pipeline in enumerate(all_pipelines, start=1):
# Take one B0 instead of all of them or correct motion.
mean_b0 = np.mean(maskdata[..., gtab.b0s_mask], -1)
warped_b0, warped_b0_affine = affine_registration(
mean_b0, t1_noskull_data, moving_affine=resliced_affine,
static_affine=t1_noskull_affine, pipeline=pipeline)

warped_b0_iso, warped_b0_iso_affine = affine_registration(
mean_b0, t1_noskull_resliced_data, moving_affine=resliced_affine,
static_affine=t1_noskull_resliced_affine, pipeline=pipeline)

save_nifti(pjoin(output_path, f"warped_b0_{p_idx}.nii.gz"), warped_b0,
t1_noskull_affine)
save_nifti(pjoin(output_path, f"warped_b0_{p_idx}_resliced.nii.gz"),
warped_b0_iso, t1_noskull_resliced_affine)

# we use the powermap instead of the mean b0

warped_pm, warped_pm_affine = affine_registration(
powermap_data, t1_noskull_data, moving_affine=resliced_affine,
static_affine=t1_noskull_affine, pipeline=pipeline)
warped_pm_iso, warped_pm_iso_affine = affine_registration(
powermap_data, t1_noskull_resliced_data,
moving_affine=resliced_affine,
static_affine=t1_noskull_resliced_affine, pipeline=pipeline)

save_nifti(pjoin(output_path, f"warped_pm_{p_idx}.nii.gz"),
warped_pm, t1_noskull_affine)
save_nifti(pjoin(output_path, f"warped_pm_{p_idx}_resliced.nii.gz"),
warped_pm_iso, t1_noskull_resliced_affine)


# Process 1: Bundles/Tractography/BUAN
Expand Down

0 comments on commit e49e499

Please sign in to comment.