diff --git a/cryocare/internals/CryoCAREDataModule.py b/cryocare/internals/CryoCAREDataModule.py index 80fb130..3e158d6 100644 --- a/cryocare/internals/CryoCAREDataModule.py +++ b/cryocare/internals/CryoCAREDataModule.py @@ -17,10 +17,13 @@ def __init__(self, tomo_paths_odd=None, tomo_paths_even=None, n_samples_per_tomo self.n_samples_per_tomo = n_samples_per_tomo self.tilt_axis = tilt_axis - tilt_axis_index = ["Z", "Y", "X"].index(self.tilt_axis) - rot_axes = [0, 1, 2] - rot_axes.remove(tilt_axis_index) - self.rot_axes = tuple(rot_axes) + if self.tilt_axis is not None: + tilt_axis_index = ["Z", "Y", "X"].index(self.tilt_axis) + rot_axes = [0, 1, 2] + rot_axes.remove(tilt_axis_index) + self.rot_axes = tuple(rot_axes) + else: + self.rot_axes = None self.extraction_shapes = extraction_shapes self.mean = mean @@ -60,7 +63,7 @@ def save(self, path): @classmethod def load(cls, path): - tmp = np.load(path) + tmp = np.load(path, allow_pickle=True) tomo_paths_odd = [str(p) for p in tmp['tomo_paths_odd']] tomo_paths_even = [str(p) for p in tmp['tomo_paths_even']] mean = tmp['mean'] @@ -70,7 +73,10 @@ def load(cls, path): sample_shape = tmp['sample_shape'] shuffle = tmp['shuffle'] coords = tmp['coords'] - tilt_axis = tmp['tilt_axis'] + if isinstance(tmp['tilt_axis'], np.ndarray): + tilt_axis = None + else: + tilt_axis = tmp['tilt_axis'] ds = cls(tomo_paths_odd=tomo_paths_odd, tomo_paths_even=tomo_paths_even, @@ -131,11 +137,12 @@ def create_random_coords(self, z, y, x, n_samples): def augment(self, x, y): if self.tilt_axis is not None: - rot_k = np.random.randint(0, 4, x.shape[0]) + if self.sample_shape[0] == self.sample_shape[1] and \ + self.sample_shape[0] == self.sample_shape[2]: + rot_k = np.random.randint(0, 4, 1) - for i in range(x.shape[0]): - x[i,...,0] = np.rot90(x[i,...,0], k=rot_k[i], axes=self.rot_axes) - y[i,...,0] = np.rot90(y[i,...,0], k=rot_k[i], axes=self.rot_axes) + x[...,0] = np.rot90(x[...,0], k=rot_k, axes=self.rot_axes) + y[...,0] = np.rot90(y[...,0], k=rot_k, axes=self.rot_axes) if np.random.rand() > 0.5: @@ -157,7 +164,6 @@ def __getitem__(self, idx): odd_subvolume = self.tomos_odd[tomo_index].data[z:z + self.sample_shape[0], y:y + self.sample_shape[1], x:x + self.sample_shape[2]] - return self.augment(np.array(even_subvolume)[..., np.newaxis], np.array(odd_subvolume)[..., np.newaxis]) def __iter__(self):