Skip to content

Commit

Permalink
Merge pull request #33 from juglab/fix-quality-issues
Browse files Browse the repository at this point in the history
Fix quality issues
  • Loading branch information
tibuch authored Oct 6, 2022
2 parents 4dfb9eb + 414bd5d commit 30a0b4b
Showing 1 changed file with 35 additions and 10 deletions.
45 changes: 35 additions & 10 deletions cryocare/internals/CryoCAREDataModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,19 @@
class CryoCARE_Dataset(tf.keras.utils.Sequence):
def __init__(self, tomo_paths_odd=None, tomo_paths_even=None, n_samples_per_tomo=None,
extraction_shapes=None, mean=None, std=None,
sample_shape=(64, 64, 64), shuffle=True, n_normalization_samples=500):
sample_shape=(64, 64, 64), shuffle=True, n_normalization_samples=500, tilt_axis=None):
self.tomo_paths_odd = tomo_paths_odd
self.tomo_paths_even = tomo_paths_even
self.n_samples_per_tomo = n_samples_per_tomo
self.tilt_axis = tilt_axis

if self.tilt_axis is not None:
tilt_axis_index = ["Z", "Y", "X"].index(self.tilt_axis)
rot_axes = [0, 1, 2]
rot_axes.remove(tilt_axis_index)
self.rot_axes = tuple(rot_axes)
else:
self.rot_axes = None

self.extraction_shapes = extraction_shapes
self.mean = mean
Expand Down Expand Up @@ -49,11 +58,12 @@ def save(self, path):
extraction_shapes=self.extraction_shapes,
sample_shape=self.sample_shape,
shuffle=self.shuffle,
coords=self.coords)
coords=self.coords,
tilt_axis=self.tilt_axis)

@classmethod
def load(cls, path):
tmp = np.load(path)
tmp = np.load(path, allow_pickle=True)
tomo_paths_odd = [str(p) for p in tmp['tomo_paths_odd']]
tomo_paths_even = [str(p) for p in tmp['tomo_paths_even']]
mean = tmp['mean']
Expand All @@ -63,6 +73,10 @@ def load(cls, path):
sample_shape = tmp['sample_shape']
shuffle = tmp['shuffle']
coords = tmp['coords']
if isinstance(tmp['tilt_axis'], np.ndarray):
tilt_axis = None
else:
tilt_axis = tmp['tilt_axis']

ds = cls(tomo_paths_odd=tomo_paths_odd,
tomo_paths_even=tomo_paths_even,
Expand All @@ -71,7 +85,8 @@ def load(cls, path):
n_samples_per_tomo=n_samples_per_tomo,
extraction_shapes=extraction_shapes,
sample_shape=sample_shape,
shuffle=shuffle)
shuffle=shuffle,
tilt_axis=tilt_axis)
ds.coords = coords
return ds

Expand Down Expand Up @@ -120,7 +135,16 @@ def create_random_coords(self, z, y, x, n_samples):

return np.stack([z_coords, y_coords, x_coords], -1)

def random_swapper(self, x, y):
def augment(self, x, y):
if self.tilt_axis is not None:
if self.sample_shape[0] == self.sample_shape[1] and \
self.sample_shape[0] == self.sample_shape[2]:
rot_k = np.random.randint(0, 4, 1)

x[...,0] = np.rot90(x[...,0], k=rot_k, axes=self.rot_axes)
y[...,0] = np.rot90(y[...,0], k=rot_k, axes=self.rot_axes)


if np.random.rand() > 0.5:
return y, x
else:
Expand All @@ -140,8 +164,7 @@ def __getitem__(self, idx):
odd_subvolume = self.tomos_odd[tomo_index].data[z:z + self.sample_shape[0],
y:y + self.sample_shape[1],
x:x + self.sample_shape[2]]

return self.random_swapper(np.array(even_subvolume)[..., np.newaxis], np.array(odd_subvolume)[..., np.newaxis])
return self.augment(np.array(even_subvolume)[..., np.newaxis], np.array(odd_subvolume)[..., np.newaxis])

def __iter__(self):
for idx in self.indices:
Expand Down Expand Up @@ -182,7 +205,8 @@ def setup(self, tomo_paths_odd, tomo_paths_even, n_samples_per_tomo, validation_
n_samples_per_tomo * (1 - validation_fraction)),
extraction_shapes=train_extraction_shapes,
sample_shape=sample_shape,
shuffle=True, n_normalization_samples=n_normalization_samples)
shuffle=True, n_normalization_samples=n_normalization_samples,
tilt_axis=tilt_axis)

self.val_dataset = CryoCARE_Dataset(tomo_paths_odd=tomo_paths_odd,
tomo_paths_even=tomo_paths_even,
Expand All @@ -191,7 +215,8 @@ def setup(self, tomo_paths_odd, tomo_paths_even, n_samples_per_tomo, validation_
n_samples_per_tomo=int(n_samples_per_tomo * validation_fraction),
extraction_shapes=val_extraction_shapes,
sample_shape=sample_shape,
shuffle=False)
shuffle=False,
tilt_axis=None)

def save(self, path):
self.train_dataset.save(join(path, 'train_data.npz'))
Expand All @@ -211,7 +236,7 @@ def __compute_extraction_shapes__(self, even_path, odd_path, tilt_axis_index, sa
assert even.data.shape[1] > 2 * sample_shape[1]
assert even.data.shape[2] > 2 * sample_shape[2]

val_cut_off = int(even.data.shape[tilt_axis_index] * validation_fraction)
val_cut_off = int(even.data.shape[tilt_axis_index] * (1 - validation_fraction))
if ((even.data.shape[tilt_axis_index] - val_cut_off) < sample_shape[tilt_axis_index]) or val_cut_off < sample_shape[tilt_axis_index]:
val_cut_off = even.data.shape[tilt_axis_index] - sample_shape[tilt_axis_index] - 1

Expand Down

0 comments on commit 30a0b4b

Please sign in to comment.