-
Notifications
You must be signed in to change notification settings - Fork 2
/
testDTI.py
214 lines (174 loc) · 7.99 KB
/
testDTI.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
#%%
from TumorGrowthToolkit.FK_DTI import FK_DTI_Solver
from TumorGrowthToolkit.FK import Solver as FK_Solver
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors
import time
import scipy.ndimage
import nibabel as nib
import os
import torch
from scipy.ndimage import binary_dilation
def elongate_tensor_along_main_axis_torch(tensor_array, scale_factor):
    """Rescale the largest eigenvalue of each 3x3 tensor while keeping its trace.

    Each tensor is eigendecomposed with ``torch.linalg.eigh``. Its largest
    eigenvalue is multiplied by ``scale_factor``, and the resulting change is
    compensated equally (with opposite sign) by the two remaining eigenvalues,
    so the eigenvalue sum (trace) is unchanged. The eigenvectors are reused,
    so only the eigenvalues change, not the eigenvectors.

    Args:
        tensor_array: torch tensor of shape (..., 3, 3). ``eigh`` assumes it is
            symmetric (only one triangle is read). It is converted to float32.
        scale_factor: multiplier for the largest eigenvalue (< 1 shrinks it,
            > 1 enlarges it).

    Returns:
        float32 torch tensor with the same shape as ``tensor_array``.

    Note:
        Large ``scale_factor`` values can push the two minor eigenvalues below
        zero; the result is then no longer positive semi-definite.
    """
    tensor_array = tensor_array.float()
    eigenvalues, eigenvectors = torch.linalg.eigh(tensor_array)
    original_sum = eigenvalues.sum(dim=-1, keepdim=True)

    # Locate and scale the largest eigenvalue of every tensor.
    max_idx = torch.argmax(eigenvalues, dim=-1, keepdim=True)
    max_eigenvalues = torch.gather(eigenvalues, -1, max_idx)
    scaled_max = max_eigenvalues * scale_factor
    difference = scaled_max - max_eigenvalues

    # True for the two non-principal eigenvalues.
    minor_mask = torch.ones_like(eigenvalues, dtype=torch.bool)
    minor_mask.scatter_(-1, max_idx, 0)

    # Compensate the change with the two minor eigenvalues, then write the
    # scaled principal eigenvalue in place. The scaled value must be inserted
    # BEFORE measuring the residual; otherwise the residual re-adds part of
    # the compensation and the trace drifts.
    e_final = torch.where(minor_mask, eigenvalues - difference / 2, eigenvalues)
    e_final.scatter_(-1, max_idx, scaled_max)

    # Spread any floating-point residual over the two minor eigenvalues only.
    residual = original_sum - e_final.sum(dim=-1, keepdim=True)
    e_final = e_final + torch.where(minor_mask, residual / 2, torch.zeros_like(residual))

    # Reconstruct V diag(e) V^T with the original eigenvectors.
    return eigenvectors @ torch.diag_embed(e_final) @ eigenvectors.transpose(-2, -1)
#%%
# Input volumes for subject tgm051 (pre-operative session, SRI space).
RGB_DIR = "/mnt/8tb_slot8/jonas/datasets/TGM/rgbResults"
SUBJECT_DIR = "/mnt/8tb_slot8/jonas/datasets/TGM/tgm/tgm051/preop"

tissue = nib.load(os.path.join(RGB_DIR, "sub-tgm051_ses-preop_space-sri_dti_RGB.nii.gz")).get_fdata()
print('shape: (x, y, z, fa-diffusion) :', tissue.shape)
# Load the segmentation image once: both its voxel data and its affine are
# needed (the affine is later used as the output affine when saving).
seg_img = nib.load(os.path.join(SUBJECT_DIR, "sub-tgm051_ses-preop_space-sri_seg.nii.gz"))
seg = seg_img.get_fdata()
affine = seg_img.affine
brainTissue = nib.load(os.path.join(SUBJECT_DIR, "sub-tgm051_ses-preop_space-sri_tissuemask.nii.gz")).get_fdata()
tissueTensor = nib.load(os.path.join(RGB_DIR, "sub-tgm051_ses-preop_space-sri_dti_tensor.nii.gz")).get_fdata()
brainMask = nib.load(os.path.join(SUBJECT_DIR, "sub-tgm051_ses-preop_space-sri_brainmask.nii.gz")).get_fdata()
# Normalize the tensor field so its largest entry is 1.
tissueTensor = tissueTensor / np.max(tissueTensor)
# Tissue-mask label 1 is treated as CSF: dilate it by one voxel and zero both
# the RGB map and the tensor field there.
CSFMask = binary_dilation(brainTissue == 1, iterations=1)
tissue[CSFMask] = 0
tissueTensor[CSFMask] = 0
#%%
def makeXYZ_rgb_from_tensor(tensor, brainMask, target_mean=0.2):
    """Build a 3-channel (x, y, z) map from the diagonal of a voxel-wise tensor field.

    Channel ``c`` holds ``tensor[:, :, :, c, c]``. The map is rescaled so that
    its mean over all brain voxels (``brainMask > 0``) and all three channels
    equals ``target_mean``, then clipped to [0, 1] for stability.

    Args:
        tensor: array indexed as ``tensor[x, y, z, i, j]``; only the diagonal
            entries ``i == j`` for i in 0..2 are read.
        brainMask: array over the first three axes of ``tensor``; voxels with a
            value > 0 define the region used for the mean.
        target_mean: mean the in-brain map is rescaled to (default 0.2).

    Returns:
        float64 array of shape ``tensor.shape[:3] + (3,)``. If the mask is
        empty or the in-brain mean is 0, the rescaling is skipped (avoiding a
        division by zero) and the map is only clipped.
    """
    tensor = np.asarray(tensor)
    # Shape comes from the tensor itself, not from any module-level volume.
    output = np.zeros(tensor.shape[:3] + (3,))
    for axis in range(3):
        output[..., axis] = tensor[:, :, :, axis, axis]

    in_brain = np.asarray(brainMask) > 0
    mean_value = output[in_brain].mean() if in_brain.any() else 0.0
    if mean_value != 0:
        output /= mean_value
        output *= target_mean
    np.clip(output, 0, 1, out=output)
    return output
tissueFromTensor = makeXYZ_rgb_from_tensor(tissueTensor, brainMask)
#%%
# Multiplier for the largest eigenvalue of every voxel tensor
# (< 1 shrinks, > 1 enlarges the principal eigenvalue).
scale_factor = 0.5  # other values tried: 1.5, 250000.0
# Apply the transformation
tensor_array_prime = elongate_tensor_along_main_axis_torch(torch.from_numpy(tissueTensor), scale_factor).numpy()
tissueFromScaledTensor = makeXYZ_rgb_from_tensor(tensor_array_prime, brainMask)
#%%
preview_slice = 75  # axial slice used for the visual comparison
plt.title('Tensor tissue')
plt.imshow((tissueFromTensor / np.max(tissueFromTensor))[:, :, preview_slice, :])
plt.show()
plt.title('Scaled Tensor tissue by factor: ' + str(scale_factor))
plt.imshow((tissueFromScaledTensor / np.max(tissueFromScaledTensor))[:, :, preview_slice, :])
plt.show()

# Histogram of all in-brain channel values, before and after scaling.
in_brain = brainMask > 0
plt.title('')
plt.hist(tissueFromTensor[in_brain].flatten(), bins=100, alpha=0.5, label='tissueFromTensor')
plt.hist(tissueFromScaledTensor[in_brain].flatten(), bins=100, alpha=0.5, label='tissueFromScaledTensor_' + str(scale_factor))
plt.legend(loc='upper right')
# Show explicitly: without this, running the file as a plain script leaves the
# histogram as the current figure and the next imshow is drawn into it.
plt.show()
print("sum of tissueFromTensor: ", np.sum(tissueFromTensor[in_brain]))
print("sum of tissueFromScaledTensor: ", np.sum(tissueFromScaledTensor[in_brain]))
# %%
# Configuration for FK_DTI_Solver (tensor-derived tissue map).
parameters = {
    'Dw': 0.7,  # maximum diffusion coefficient
    'rho': 0.4,  # proliferation rate
    'rgb': tissueFromScaledTensor,  # diffusion tissue map as shown above
    'diffusionTensorExponent': 1.0,  # exponent for the diffusion tensor, 1.0 for a linear relationship
    'NxT1_pct': 0.45,  # tumor position [%]
    'NyT1_pct': 0.32,
    'NzT1_pct': 0.60,
    'init_scale': 1.,  # scale of the initial Gaussian
    'resolution_factor': 1,  # resolution scaling for calculations
    'verbose': True,  # print timesteps
    'time_series_solution_Nt': 64,  # number of timesteps in the output
}
# Voxel indices of the tumor seed, from the fractional positions above.
x = int(tissue.shape[0] * parameters["NxT1_pct"])
y = int(tissue.shape[1] * parameters["NyT1_pct"])
z = int(tissue.shape[2] * parameters["NzT1_pct"])
# scipy.ndimage.measurements is a deprecated namespace; use the top-level function.
com = scipy.ndimage.center_of_mass(seg)
plt.imshow(brainTissue[:, :, z])
plt.show()
plt.imshow(tissue[:, :, z])
plt.title('Fractional Anisotropy Tissue')
plt.xlabel('y')
plt.ylabel('x')
# imshow places array axis 0 vertically, so the seed is plotted as (y, x).
plt.scatter(y, x, c='r')
plt.show()
#%%
# Run the FK_DTI solver, time it and report its status.
# perf_counter is monotonic, so the measurement is immune to wall-clock jumps.
start_time = time.perf_counter()
fK_DTI_Solver = FK_DTI_Solver(parameters)
result = fK_DTI_Solver.solve()
end_time = time.perf_counter()
execution_time = int(end_time - start_time)  # truncated to whole seconds
print(f"Execution Time: {execution_time} seconds")
if result['success']:
    print("Simulation successful!")
else:
    print("Error occurred:", result['error'])
# Custom colormaps: grayscale background and a black-green-yellow-red overlay.
cmap1 = matplotlib.colors.LinearSegmentedColormap.from_list('my_cmap', ['black', 'white'], 256)
cmap2 = matplotlib.colors.LinearSegmentedColormap.from_list('my_cmap2', ['black', 'green', 'yellow', 'red'], 256)
#%% run normal FK
# Baseline run with the standard FK solver, which receives gray/white-matter
# masks instead of the tensor-derived map. Tissue-mask labels 2 and 3 are used
# as gray and white matter; voxels in the dilated CSF mask are removed from both.
gm = (brainTissue == 2) & ~CSFMask
wm = (brainTissue == 3) & ~CSFMask
# NOTE(review): rho is 0.3 here but 0.4 in the DTI parameters — confirm whether
# the two simulations are meant to be directly comparable.
parametersFK = dict(
    Dw=0.7,  # maximum diffusion coefficient
    rho=0.3,  # proliferation rate
    gm=gm,
    wm=wm,
    NxT1_pct=0.45,  # tumor position [%]
    NyT1_pct=0.32,
    NzT1_pct=0.60,
    init_scale=1.,  # scale of the initial Gaussian
    resolution_factor=1,  # resolution scaling for calculations
    verbose=True,  # print timesteps
    time_series_solution_Nt=64,  # number of timesteps in the output
)
fkSolver = FK_Solver(parametersFK)
resultFK = fkSolver.solve()
#%%
# Axial slice for the time-series panels: the seed slice index offset by +5 in z.
NzT = 5 + int(tissue.shape[2] * parameters['NzT1_pct'])
def plot_time_series(wm_data, time_series_data, slice_index, *, wm_cmap=None,
                     tumor_cmap=None, n_panels=8, n_cols=4):
    """Plot evenly spaced frames of a simulated time series over a background slice.

    Args:
        wm_data: 3-D background volume; ``wm_data[:, :, slice_index]`` is drawn
            in every panel with display range [0, 1].
        time_series_data: 4-D array indexed as ``[t, x, y, z]``; the frame
            ``[t, :, :, slice_index]`` is overlaid (alpha 0.65, range [0, 1]).
        slice_index: z index of the slice to display.
        wm_cmap: background colormap; defaults to a black-to-white ramp.
        tumor_cmap: overlay colormap; defaults to black-green-yellow-red.
        n_panels: number of frames, chosen evenly over the time axis.
        n_cols: panels per row.

    Raises:
        ValueError: if ``time_series_data`` has no time points.
    """
    # Build the default colormaps locally so the function does not depend on
    # module-level globals defined in another cell.
    if wm_cmap is None:
        wm_cmap = matplotlib.colors.LinearSegmentedColormap.from_list('my_cmap', ['black', 'white'], 256)
    if tumor_cmap is None:
        tumor_cmap = matplotlib.colors.LinearSegmentedColormap.from_list('my_cmap2', ['black', 'green', 'yellow', 'red'], 256)

    n_frames = time_series_data.shape[0]
    if n_frames == 0:
        raise ValueError("time_series_data contains no time points")
    n_rows = -(-n_panels // n_cols)  # ceiling division

    plt.figure(figsize=(24, 12))
    time_points = np.linspace(0, n_frames - 1, n_panels, dtype=int)
    background = wm_data[:, :, slice_index]
    for i, t in enumerate(time_points):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(background, cmap=wm_cmap, vmin=0, vmax=1, alpha=1)
        plt.imshow(time_series_data[t, :, :, slice_index], cmap=tumor_cmap, vmin=0, vmax=1, alpha=0.65)
        plt.title(f"Time Slice {t + 1}")
    plt.tight_layout()
    plt.show()
# Show DTI-based and standard FK growth over the same background:
# the per-voxel mean of the tissue map's channels.
rgb_mean = np.mean(tissue, axis=3)
for solver_result in (result, resultFK):
    plot_time_series(rgb_mean, solver_result['time_series'], NzT)
#%%
# Overlay at seed slice z: voxels outside the tissue mask are darkened (gray),
# the segmentation is drawn in green and the simulated final state > 0.001 in red.
tissue_slice = brainTissue[:, :, z]
seg_slice = seg[:, :, z]
final_slice = result['final_state'][:, :, z]
plt.imshow(tissue_slice > 0, alpha=0.5 * (tissue_slice == 0), cmap='gray')
plt.imshow(seg_slice, alpha=0.5 * (seg_slice > 0), cmap='Greens')
plt.imshow(final_slice, alpha=0.5 * (final_slice > 0.001), cmap="Reds")
plt.title('Tumor')
# Show explicitly so the figure is displayed when the file runs as a plain script.
plt.show()
#%% save results
# Write the final simulated state and the tensor-derived tissue map as NIfTI,
# using the segmentation's affine so they align with the input volumes.
path = "/mnt/8tb_slot8/jonas/workingDirDatasets/tgm/dtiFirstTests/tgm051/"
os.makedirs(path, exist_ok=True)
outputs = {
    "resultTensor.nii.gz": result['final_state'],
    "tissueFromTensor.nii.gz": tissueFromTensor,
}
for filename, volume in outputs.items():
    nib.save(nib.Nifti1Image(volume, affine=affine), os.path.join(path, filename))
# %%