diff --git a/n2v/internals/N2V_DataGenerator.py b/n2v/internals/N2V_DataGenerator.py index 588ae32..33ca496 100644 --- a/n2v/internals/N2V_DataGenerator.py +++ b/n2v/internals/N2V_DataGenerator.py @@ -131,8 +131,9 @@ def generate_patches_from_list(self, data, num_patches_per_img=None, shape=(256, """ patches = [] for img in data: - p = self.generate_patches(img, num_patches=num_patches_per_img, shape=shape, augment=augment) - patches.append(p) + for s in range(img.shape[0]): + p = self.generate_patches(img[s][np.newaxis], num_patches=num_patches_per_img, shape=shape, augment=augment) + patches.append(p) patches = np.concatenate(patches, axis=0) @@ -210,12 +211,11 @@ def __extract_patches__(self, data, num_patches=None, shape=(256, 256), n_dims=2 patches = [] if n_dims == 2: for i in range(num_patches): - s = np.random.randint(0, data.shape[0]) y, x = np.random.randint(0, data.shape[1] - shape[0] + 1), np.random.randint(0, data.shape[ 2] - shape[ 1] + 1) - patches.append(data[s, y:y + shape[0], x:x + shape[1]]) + patches.append(data[0, y:y + shape[0], x:x + shape[1]]) if len(patches) > 1: return np.stack(patches) @@ -223,13 +223,12 @@ def __extract_patches__(self, data, num_patches=None, shape=(256, 256), n_dims=2 return np.array(patches)[np.newaxis] elif n_dims == 3: for i in range(num_patches): - s = np.random.randint(0, data.shape[0]) z, y, x = np.random.randint(0, data.shape[1] - shape[0] + 1), np.random.randint(0, data.shape[ 2] - shape[ 1] + 1), np.random.randint( 0, data.shape[3] - shape[2] + 1) - patches.append(data[s, z:z + shape[0], y:y + shape[1], x:x + shape[2]]) + patches.append(data[0, z:z + shape[0], y:y + shape[1], x:x + shape[2]]) if len(patches) > 1: return np.stack(patches)