diff --git a/cryocare/internals/CryoCAREDataModule.py b/cryocare/internals/CryoCAREDataModule.py index 76ff1de..80fb130 100644 --- a/cryocare/internals/CryoCAREDataModule.py +++ b/cryocare/internals/CryoCAREDataModule.py @@ -134,8 +134,8 @@ def augment(self, x, y): rot_k = np.random.randint(0, 4, x.shape[0]) for i in range(x.shape[0]): - x[i] = np.rot90(x[i], k=rot_k[i], axes=self.rot_axes) - y[i] = np.rot90(y[i], k=rot_k[i], axes=self.rot_axes) + x[i,...,0] = np.rot90(x[i,...,0], k=rot_k[i], axes=self.rot_axes) + y[i,...,0] = np.rot90(y[i,...,0], k=rot_k[i], axes=self.rot_axes) if np.random.rand() > 0.5: