Skip to content

Commit

Permalink
Fix augmentation.
Browse files Browse the repository at this point in the history
  • Loading branch information
tibuch committed Oct 5, 2022
1 parent 92d2c5f commit 414bd5d
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions cryocare/internals/CryoCAREDataModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@ def __init__(self, tomo_paths_odd=None, tomo_paths_even=None, n_samples_per_tomo
self.n_samples_per_tomo = n_samples_per_tomo
self.tilt_axis = tilt_axis

tilt_axis_index = ["Z", "Y", "X"].index(self.tilt_axis)
rot_axes = [0, 1, 2]
rot_axes.remove(tilt_axis_index)
self.rot_axes = tuple(rot_axes)
if self.tilt_axis is not None:
tilt_axis_index = ["Z", "Y", "X"].index(self.tilt_axis)
rot_axes = [0, 1, 2]
rot_axes.remove(tilt_axis_index)
self.rot_axes = tuple(rot_axes)
else:
self.rot_axes = None

self.extraction_shapes = extraction_shapes
self.mean = mean
Expand Down Expand Up @@ -60,7 +63,7 @@ def save(self, path):

@classmethod
def load(cls, path):
tmp = np.load(path)
tmp = np.load(path, allow_pickle=True)
tomo_paths_odd = [str(p) for p in tmp['tomo_paths_odd']]
tomo_paths_even = [str(p) for p in tmp['tomo_paths_even']]
mean = tmp['mean']
Expand All @@ -70,7 +73,10 @@ def load(cls, path):
sample_shape = tmp['sample_shape']
shuffle = tmp['shuffle']
coords = tmp['coords']
tilt_axis = tmp['tilt_axis']
if isinstance(tmp['tilt_axis'], np.ndarray):
tilt_axis = None
else:
tilt_axis = tmp['tilt_axis']

ds = cls(tomo_paths_odd=tomo_paths_odd,
tomo_paths_even=tomo_paths_even,
Expand Down Expand Up @@ -131,11 +137,12 @@ def create_random_coords(self, z, y, x, n_samples):

def augment(self, x, y):
if self.tilt_axis is not None:
rot_k = np.random.randint(0, 4, x.shape[0])
if self.sample_shape[0] == self.sample_shape[1] and \
self.sample_shape[0] == self.sample_shape[2]:
rot_k = np.random.randint(0, 4, 1)

for i in range(x.shape[0]):
x[i,...,0] = np.rot90(x[i,...,0], k=rot_k[i], axes=self.rot_axes)
y[i,...,0] = np.rot90(y[i,...,0], k=rot_k[i], axes=self.rot_axes)
x[...,0] = np.rot90(x[...,0], k=rot_k, axes=self.rot_axes)
y[...,0] = np.rot90(y[...,0], k=rot_k, axes=self.rot_axes)


if np.random.rand() > 0.5:
Expand All @@ -157,7 +164,6 @@ def __getitem__(self, idx):
odd_subvolume = self.tomos_odd[tomo_index].data[z:z + self.sample_shape[0],
y:y + self.sample_shape[1],
x:x + self.sample_shape[2]]

return self.augment(np.array(even_subvolume)[..., np.newaxis], np.array(odd_subvolume)[..., np.newaxis])

def __iter__(self):
Expand Down

0 comments on commit 414bd5d

Please sign in to comment.