Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add label iamge registration and tests. #743

Merged
merged 2 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ants/registration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@
from .integrate_velocity_field import integrate_velocity_field
from .invert_displacement_field import invert_displacement_field
from .landmark_transforms import fit_transform_to_paired_points, fit_time_varying_transform_to_point_sets
from .registration import registration, motion_correction
from .simulate_displacement_field import simulate_displacement_field
from .registration import registration, motion_correction, label_image_registration
from .simulate_displacement_field import simulate_displacement_field
326 changes: 324 additions & 2 deletions ants/registration/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
ANTsPy Registration
"""
__all__ = ["registration",
"motion_correction"]
"motion_correction",
"label_image_registration"]

import os
import numpy as np
from tempfile import mktemp
import glob
Expand Down Expand Up @@ -1565,3 +1565,325 @@ def motion_correction(
"motion_parameters": motion_parameters,
"FD": FD,
}

def label_image_registration(fixed_label_images,
moving_label_images,
fixed_intensity_images=None,
moving_intensity_images=None,
fixed_mask=None,
moving_mask=None,
type_of_linear_transform='affine',
type_of_transform='antsRegistrationSyNQuick[so]',
label_image_weighting=1.0,
output_prefix='',
random_seed=None,
verbose=False):

"""
Perform pairwise registration using fixed and moving sets of label
images (and, optionally, sets of corresponding intensity images).

Arguments
---------
fixed_label_images : single or list of ANTsImage
A single (or set of) fixed label image(s).

moving_label_images : single or list of ANTsImage
A single (or set of) moving label image(s).

fixed_intensity_images : single or list of ANTsImage
Optional---a single (or set of) fixed intensity image(s).

moving_intensity_images : single or list of ANTsImage
Optional---a single (or set of) moving intensity image(s).

fixed_mask : ANTsImage
Defines region for similarity metric calculation in the space
of the fixed image.

moving_mask : ANTsImage
Defines region for similarity metric calculation in the space
of the moving image.

type_of_linear_transform : string
Use label images with the centers of mass to a calculate linear
transform of type 'rigid', 'similarity', or 'affine'.

type_of_transform : string
Only works with deformable-only transforms, specifically the family
of antsRegistrationSyN*[so] or antsRegistrationSyN*[bo] transforms.
See 'type_of_transform' in ants.registration.

label_image_weighting : float or list of floats
Relative weighting for the label images.

output_prefix : string
Define the output prefix for the filenames of the output transform
files.

verbose : boolean
Print progress to the screen.

Returns
-------
Set of transforms definining the mapping to/from the fixed image domain
to the moving image domain.

Example
-------
>>>
>>>
"""

# Perform validation check on the input

if isinstance(fixed_label_images, ants.ANTsImage):
fixed_label_images = [ants.image_clone(fixed_label_images)]
if isinstance(moving_label_images, ants.ANTsImage):
moving_label_images = [ants.image_clone(moving_label_images)]

if len(fixed_label_images) != len(moving_label_images):
raise ValueError("The number of fixed and moving label images do not match.")

if fixed_intensity_images is not None or moving_intensity_images is not None:
if isinstance(fixed_intensity_images, ants.ANTsImage):
fixed_intensity_images = [ants.image_clone(fixed_intensity_images)]
if isinstance(moving_intensity_images, ants.ANTsImage):
moving_intensity_images = [ants.image_clone(moving_intensity_images)]
if len(fixed_intensity_images) != len(moving_intensity_images):
raise ValueError("The number of fixed and moving intensity images do not match.")

label_image_weights = list()
if isinstance(label_image_weighting, (int, float)):
label_image_weights = [label_image_weighting] * len(fixed_label_images)
else:
label_image_weights = tuple(label_image_weighting)
if len(fixed_label_images) != len(label_image_weights):
raise ValueError("The length of label_image_weights must" +
"match the number of label image pairs.")

image_dimension = fixed_label_images[0].dimension

if output_prefix == "" or output_prefix is None or len(output_prefix) == 0:
output_prefix = mktemp()

allowable_linear_transforms = ['rigid', 'similarity', 'affine']
if not type_of_linear_transform in allowable_linear_transforms:
raise ValueError("Unrecognized linear transform.")

do_deformable = False
if type_of_transform is not None or len(type_of_transform) > 0:
do_deformable = True

common_label_ids = list()
total_number_of_labels = 0
for i in range(len(fixed_label_images)):
fixed_label_geoms = ants.label_geometry_measures(fixed_label_images[i])
fixed_label_ids = np.array(fixed_label_geoms['Label'])
moving_label_geoms = ants.label_geometry_measures(moving_label_images[i])
moving_label_ids = np.array(moving_label_geoms['Label'])
common_label_ids.append(np.intersect1d(moving_label_ids, fixed_label_ids))
total_number_of_labels = len(common_label_ids[i])
if verbose:
print("Common label ids for image pair ", str(i), ": ", common_label_ids[i])
if len(common_label_ids) == 0:
raise ValueError("No common labels for image pair " + str(i))

if verbose:
print("Total number of labels: " + str(total_number_of_labels))

##############################
#
# Linear transform
#
##############################

linear_xfrm = None
if type_of_linear_transform is not None:

if verbose:
print("\n\nComputing linear transform.\n")

if total_number_of_labels < 3:
raise ValueError(" Number of labels must be >= 3.")

fixed_centers_of_mass = np.zeros((total_number_of_labels, image_dimension))
moving_centers_of_mass = np.zeros((total_number_of_labels, image_dimension))
deformable_multivariate_extras = list()

count = 0
for i in range(len(common_label_ids)):
for j in range(len(common_label_ids[i])):
label = common_label_ids[i][j]
if verbose:
print(" Finding center of mass for label " + str(label))
fixed_single_label_image = ants.threshold_image(fixed_label_images[i], label, label, 1, 0)
fixed_centers_of_mass[count, :] = ants.get_center_of_mass(fixed_single_label_image)
moving_single_label_image = ants.threshold_image(moving_label_images[i], label, label, 1, 0)
moving_centers_of_mass[count, :] = ants.get_center_of_mass(moving_single_label_image)
count += 1
if do_deformable:
deformable_multivariate_extras.append(["MSQ", fixed_single_label_image,
moving_single_label_image, label_image_weighting, 0])

linear_xfrm = ants.fit_transform_to_paired_points(moving_centers_of_mass,
fixed_centers_of_mass,
transform_type=type_of_linear_transform,
verbose=verbose)

linear_xfrm_file = output_prefix + "0GenericAffine.mat"
ants.write_transform(linear_xfrm, linear_xfrm_file)

##############################
#
# Deformable transform
#
##############################

if do_deformable:

if verbose:
print("\n\nComputing deformable transform using images.\n")

do_quick = False
do_repro = False

if "Quick" in type_of_transform:
do_quick = True
elif "Repro" in type_of_transform:
do_repro = True
random_seed = str(1)

intensity_metric_parameter = None
spline_distance = 26
if "[" in type_of_transform and "]" in type_of_transform:
subtype_of_transform = type_of_transform.split("[")[1].split("]")[0]
if not ('bo' in subtype_of_transform or 'so' in subtype_of_transform):
raise ValueError("See only 'so' or 'bo' transforms are available.")
if "," in subtype_of_transform:
subtype_of_transform_args = subtype_of_transform.split(",")
subtype_of_transform = subtype_of_transform_args[0]
intensity_metric_parameter = subtype_of_transform_args[1]
if len(subtype_of_transform_args) > 2:
spline_distance = subtype_of_transform_args[2]

syn_stage = list()

intensity_metric = None
if fixed_intensity_images is not None and len(fixed_intensity_images) > 0:
if do_quick:
intensity_metric = "MI"
if intensity_metric_parameter is None:
intensity_metric_parameter = 32
if not do_quick or do_repro:
intensity_metric = "CC"
if intensity_metric_parameter is None:
intensity_metric_parameter = 2
for i in range(1, len(fixed_intensity_images)):
syn_stage.append("--metric")
metric_string = "%s[%s,%s,%s,%s]" % (
intensity_metric,
get_pointer_string(fixed_intensity_images[i]),
get_pointer_string(moving_intensity_images[i]),
1.0, intensity_metric_parameter)
syn_stage.append(metric_string)

for kk in range(len(deformable_multivariate_extras)):
syn_stage.append("--metric")
metricString = "%s[%s,%s,%s,%s]" % (
"MSQ",
get_pointer_string(deformable_multivariate_extras[kk][1]),
get_pointer_string(deformable_multivariate_extras[kk][2]),
1.0, 0.0)
syn_stage.append(metricString)

syn_shrink_factors = "8x4x2x1"
syn_smoothing_sigmas = "3x2x1x0vox"

if do_quick:
syn_convergence = "[100x70x50x0,1e-6,10]"
else:
syn_convergence = "[100x70x50x20,1e-6,10]"

syn_stage.append("--convergence")
syn_stage.append(syn_convergence)
syn_stage.append("--shrink-factors")
syn_stage.append(syn_shrink_factors)
syn_stage.append("--smoothing-sigmas")
syn_stage.append(syn_smoothing_sigmas)

if 'b' in subtype_of_transform:
syn_stage.insert(0, "BSplineSyN[0.1," + str(spline_distance) + ",0,3]")
else:
syn_stage.insert(0, "SyN[0.1,3,0]")
syn_stage.insert(0, "--transform")

args = ["-d", str(image_dimension),
"-r", linear_xfrm_file,
"-o", output_prefix]
args.append(syn_stage)

fixed_mask_string = 'NA'
if fixed_mask is not None:
fixed_mask_binary = fixed_mask != 0
fixed_mask_string = get_pointer_string(fixed_mask_binary)

moving_mask_string = 'NA'
if moving_mask is not None:
moving_mask_binary = moving_mask != 0
moving_mask_string = get_pointer_string(moving_mask_binary)

mask_option = "[%s,%s]" % (fixed_mask_string, moving_mask_string)

args.append("-x")
args.append(mask_option)

args = list(itertools.chain.from_iterable(
itertools.repeat(x, 1)
if isinstance(x, str)
else x for x in args))

args.append("--float")
args.append("1")

if random_seed is not None:
args.append("--random-seed")
args.append(random_seed)

if verbose:
args.append("-v")
args.append("1")

processed_args = process_arguments(args)
if verbose:
print("antsRegistration " + ' '.join(processed_args))

libfn = get_lib_fn("antsRegistration")
deformable_registration_exit_error = libfn(processed_args)

if deformable_registration_exit_error != 0:
raise RuntimeError(f"Registration failed with error code {deformable_registration_exit_error}")

all_xfrms = sorted(set(glob.glob(output_prefix + "*" + "[0-9]*")))

find_inverse_warps = np.where([re.search("[0-9]InverseWarp.nii.gz", ff) for ff in all_xfrms])[0]
find_forward_warps = np.where([re.search("[0-9]Warp.nii.gz", ff) for ff in all_xfrms])[0]

if len(find_inverse_warps) > 0:
fwdtransforms = list(reversed([ff for idx, ff in enumerate(all_xfrms) if idx != find_inverse_warps[0]]))
invtransforms = [ff for idx, ff in enumerate(all_xfrms) if idx != find_forward_warps[0]]
else:
fwdtransforms = list(reversed(all_xfrms))
invtransforms = all_xfrms

if verbose:
print("\n\nResulting transforms:")
print(" fwdtransforms: ", fwdtransforms)
print(" invtransforms: ", invtransforms)

return {
"fwdtransforms": fwdtransforms,
"invtransforms": invtransforms,
}


16 changes: 16 additions & 0 deletions tests/test_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ def setUp(self):
"QuickRigid",
"DenseRigid",
"BOLDRigid",
"antsRegistrationSyNQuick[b,32,26]",
"antsRegistrationSyNQuick[s]",
"antsRegistrationSyNRepro[s]",
"antsRegistrationSyN[s]"
}

def tearDown(self):
Expand Down Expand Up @@ -451,5 +455,17 @@ def test_motion_correction(self):
fi = ants.image_read(ants.get_ants_data('ch2'))
mytx = ants.motion_correction( fi )

def test_label_image_registration(self):
fi = ants.image_read(ants.get_ants_data('r16'))
mi = ants.image_read(ants.get_ants_data('r64'))
fi = ants.resample_image(fi, (60,60), 1, 0)
mi = ants.resample_image(mi, (60,60), 1, 0)
fi_seg = ants.threshold_image(fi, "Kmeans", 3)-1
mi_seg = ants.threshold_image(mi, "Kmeans", 3)-1
mytx = ants.label_image_registration([fi_seg],
[mi_seg],
fixed_intensity_images=fi,
moving_intensity_images=mi)

if __name__ == "__main__":
run_tests()