Merge pull request #53 from juglab/fix_numPix_computation
Fix num pix computation
tibuch authored Nov 7, 2019
2 parents fbc0432 + 8fc0930 commit 8d9e5c3
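
Background for the diff below: this commit moves the blind-spot computation inside the data wrapper. Instead of being handed an absolute count (num_pix), N2V_DataWrapper now takes a percentage of patch pixels (perc_pix, default 0.198) and derives both the blind-spot count and the stratification box size from it. A minimal sketch of the new computation with the default 64×64 patch (np.prod/int stand in for the commit's np.product/np.int; the printed numbers are illustrative):

import numpy as np

perc_pix = 0.198                                     # default percentage of masked pixels
shape = (64, 64)                                     # default training patch size

num_pix = int(np.prod(shape) / 100.0 * perc_pix)     # -> 8 blind-spots per patch
box_size = int(np.round(np.sqrt(100 / perc_pix)))    # -> 22: roughly one blind-spot per 22x22 box

print(num_pix, box_size)                             # 8 22

With this scheme the blind-spot density stays constant regardless of the patch shape, which is what the assert and print statements added to N2V_DataWrapper.__init__ report.
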
Showing 13 changed files with 1,425 additions and 252 deletions.
925 changes: 925 additions & 0 deletions examples/2D/denoising2D_BSD68/BSD68_reproducibility.ipynb

Large diffs are not rendered by default.

140 changes: 99 additions & 41 deletions examples/2D/denoising2D_RGB/01_training.ipynb

Large diffs are not rendered by default.

29 changes: 19 additions & 10 deletions examples/2D/denoising2D_RGB/02_prediction.ipynb

Large diffs are not rendered by default.

110 changes: 84 additions & 26 deletions examples/2D/denoising2D_SEM/01_training.ipynb

Large diffs are not rendered by default.

26 changes: 21 additions & 5 deletions examples/2D/denoising2D_SEM/02_prediction.ipynb

Large diffs are not rendered by default.

149 changes: 112 additions & 37 deletions examples/3D/01_training.ipynb

Large diffs are not rendered by default.

20 changes: 15 additions & 5 deletions examples/3D/02_prediction.ipynb

Large diffs are not rendered by default.

75 changes: 39 additions & 36 deletions n2v/internals/N2V_DataWrapper.py
@@ -24,7 +24,7 @@ class N2V_DataWrapper(Sequence):
The manipulator used for the pixel replacement.
"""

def __init__(self, X, Y, batch_size, num_pix=1, shape=(64, 64),
def __init__(self, X, Y, batch_size, perc_pix=0.198, shape=(64, 64),
value_manipulation=None):
self.X, self.Y = X, Y
self.batch_size = batch_size
@@ -35,93 +35,94 @@ def __init__(self, X, Y, batch_size, num_pix=1, shape=(64, 64),
self.dims = len(shape)
self.n_chan = X.shape[-1]

num_pix = int(np.product(shape)/100.0 * perc_pix)
assert num_pix >= 1, "Number of blind-spot pixels is below one. At least {}% of pixels should be replaced.".format(100.0/np.product(shape))
print("{} blind-spots will be generated per training patch of size {}.".format(num_pix, shape))

if self.dims == 2:
self.patch_sampler = self.__subpatch_sampling2D__
self.box_size = np.round(np.sqrt(shape[0] * shape[1] / num_pix)).astype(np.int)
self.box_size = np.round(np.sqrt(100/perc_pix)).astype(np.int)
self.get_stratified_coords = self.__get_stratified_coords2D__
self.rand_float = self.__rand_float_coords2D__(self.box_size)
self.X_Batches = np.zeros([X.shape[0], shape[0], shape[1], X.shape[3]])
self.Y_Batches = np.zeros([Y.shape[0], shape[0], shape[1], Y.shape[3]])
elif self.dims == 3:
self.patch_sampler = self.__subpatch_sampling3D__
self.box_size = np.round(np.power(shape[0] * shape[1] * shape[2] / num_pix, 1/3.0)).astype(np.int)
self.box_size = np.round(np.sqrt(100 / perc_pix)).astype(np.int)
self.get_stratified_coords = self.__get_stratified_coords3D__
self.rand_float = self.__rand_float_coords3D__(self.box_size)
self.X_Batches = np.zeros([X.shape[0], shape[0], shape[1], shape[2], X.shape[4]])
self.Y_Batches = np.zeros([Y.shape[0], shape[0], shape[1], shape[2], Y.shape[4]])
else:
raise Exception('Dimensionality not supported.')

self.X_Batches = np.zeros((self.X.shape[0], *self.shape, self.n_chan), dtype=np.float32)
self.Y_Batches = np.zeros((self.Y.shape[0], *self.shape, 2*self.n_chan), dtype=np.float32)

def __len__(self):
return int(np.ceil(len(self.X) / float(self.batch_size)))

def on_epoch_end(self):
self.perm = np.random.permutation(len(self.X))
self.X_Batches *= 0
self.Y_Batches *= 0

def __getitem__(self, i):
idx = slice(i * self.batch_size, (i + 1) * self.batch_size)
idx = self.perm[idx]
self.patch_sampler(self.X, self.Y, self.X_Batches, self.Y_Batches, idx, self.range, self.shape)
self.patch_sampler(self.X, self.X_Batches, indices=idx, range=self.range, shape=self.shape)

for j in idx:
for c in range(self.n_chan):
for c in range(self.n_chan):
for j in idx:
coords = self.get_stratified_coords(self.rand_float, box_size=self.box_size,
shape=np.array(self.X_Batches.shape)[1:-1])

y_val = []
x_val = []
for k in range(len(coords)):
y_val.append(np.copy(self.Y_Batches[(j, *coords[k], ..., c)]))
x_val.append(self.value_manipulation(self.X_Batches[j, ..., c][...,np.newaxis], coords[k], self.dims))

self.Y_Batches[j,...,c] *= 0
self.Y_Batches[j,...,self.n_chan+c] *= 0
shape=self.shape)

for k in range(len(coords)):
self.Y_Batches[(j, *coords[k], c)] = y_val[k]
self.Y_Batches[(j, *coords[k], self.n_chan+c)] = 1
self.X_Batches[(j, *coords[k], c)] = x_val[k]
indexing = (j,) + coords + (c,)
indexing_mask = (j,) + coords + (c + self.n_chan, )
y_val = self.X_Batches[indexing]
x_val = self.value_manipulation(self.X_Batches[j, ..., c], coords, self.dims)

self.Y_Batches[indexing] = y_val
self.Y_Batches[indexing_mask] = 1
self.X_Batches[indexing] = x_val

return self.X_Batches[idx], self.Y_Batches[idx]

@staticmethod
def __subpatch_sampling2D__(X, Y, X_Batches, Y_Batches, indices, range, shape):
def __subpatch_sampling2D__(X, X_Batches, indices, range, shape):
for j in indices:
y_start = np.random.randint(0, range[0] + 1)
x_start = np.random.randint(0, range[1] + 1)
X_Batches[j] = X[j, y_start:y_start + shape[0], x_start:x_start + shape[1]]
Y_Batches[j] = Y[j, y_start:y_start + shape[0], x_start:x_start + shape[1]]
X_Batches[j] = np.copy(X[j, y_start:y_start + shape[0], x_start:x_start + shape[1]])

@staticmethod
def __subpatch_sampling3D__(X, Y, X_Batches, Y_Batches, indices, range, shape):
def __subpatch_sampling3D__(X, X_Batches, indices, range, shape):
for j in indices:
z_start = np.random.randint(0, range[0] + 1)
y_start = np.random.randint(0, range[1] + 1)
x_start = np.random.randint(0, range[2] + 1)
X_Batches[j] = X[j, z_start:z_start + shape[0], y_start:y_start + shape[1], x_start:x_start + shape[2]]
Y_Batches[j] = Y[j, z_start:z_start + shape[0], y_start:y_start + shape[1], x_start:x_start + shape[2]]
X_Batches[j] = np.copy(X[j, z_start:z_start + shape[0], y_start:y_start + shape[1], x_start:x_start + shape[2]])

@staticmethod
def __get_stratified_coords2D__(coord_gen, box_size, shape):
coords = []
box_count_y = int(np.ceil(shape[0] / box_size))
box_count_x = int(np.ceil(shape[1] / box_size))
x_coords = []
y_coords = []
for i in range(box_count_y):
for j in range(box_count_x):
y, x = next(coord_gen)
y = int(i * box_size + y)
x = int(j * box_size + x)
if (y < shape[0] and x < shape[1]):
coords.append((y, x))
return coords
y_coords.append(y)
x_coords.append(x)
return (y_coords, x_coords)

@staticmethod
def __get_stratified_coords3D__(coord_gen, box_size, shape):
coords = []
box_count_z = int(np.ceil(shape[0] / box_size))
box_count_y = int(np.ceil(shape[1] / box_size))
box_count_x = int(np.ceil(shape[2] / box_size))
x_coords = []
y_coords = []
z_coords = []
for i in range(box_count_z):
for j in range(box_count_y):
for k in range(box_count_x):
@@ -130,8 +131,10 @@ def __get_stratified_coords3D__(coord_gen, box_size, shape):
y = int(j * box_size + y)
x = int(k * box_size + x)
if (z < shape[0] and y < shape[1] and x < shape[2]):
coords.append((z, y, x))
return coords
z_coords.append(z)
y_coords.append(y)
x_coords.append(x)
return (z_coords, y_coords, x_coords)

@staticmethod
def __rand_float_coords2D__(boxsize):
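
A note on the coordinate bookkeeping above: __get_stratified_coords2D__ and __get_stratified_coords3D__ now return a tuple of per-axis coordinate lists instead of a list of (y, x) tuples, so __getitem__ can read and write every blind-spot of a channel with one NumPy fancy-indexing operation rather than a per-pixel Python loop. A toy illustration of that pattern (shapes and coordinate values are made up):

import numpy as np

X_Batches = np.random.rand(4, 64, 64, 1).astype(np.float32)   # stand-in for one batch of patches
Y_Batches = np.zeros((4, 64, 64, 2), dtype=np.float32)        # target channel + mask channel

j, c, n_chan = 0, 0, 1
coords = ([3, 17, 40], [5, 22, 60])           # (y_coords, x_coords), as returned by the 2D sampler

indexing = (j,) + coords + (c,)               # all blind-spot pixels of sample j, channel c
indexing_mask = (j,) + coords + (c + n_chan,)

Y_Batches[indexing] = X_Batches[indexing]     # keep the original values as the training target
Y_Batches[indexing_mask] = 1                  # flag the blind-spot positions in the mask channel
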
4 changes: 2 additions & 2 deletions n2v/models/n2v_standard.py
@@ -208,14 +208,14 @@ def train(self, X, validation_X, epochs=None, steps_per_epoch=None):
# Here we prepare the Noise2Void data. Our input is the noisy data X and as target we take X concatenated with
# a masking channel. The N2V_DataWrapper will take care of the pixel masking and manipulating.
training_data = N2V_DataWrapper(X, np.concatenate((X, np.zeros(X.shape, dtype=X.dtype)), axis=axes.index('C')),
self.config.train_batch_size, int(train_num_pix/100 * self.config.n2v_perc_pix),
self.config.train_batch_size, self.config.n2v_perc_pix,
self.config.n2v_patch_shape, manipulator)

# validation_Y is also validation_X plus a concatenated masking channel.
# To speed things up, we precompute the masking of the validation data.
validation_Y = np.concatenate((validation_X, np.zeros(validation_X.shape, dtype=validation_X.dtype)), axis=axes.index('C'))
n2v_utils.manipulate_val_data(validation_X, validation_Y,
num_pix=int(val_num_pix/100 * self.config.n2v_perc_pix),
perc_pix=self.config.n2v_perc_pix,
shape=val_patch_shape,
value_manipulation=manipulator)

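
With the wrapper handling the conversion, train() now passes config.n2v_perc_pix straight through, both to the training generator and to the precomputed validation masking. A rough sketch of that wiring outside of train(), using the module paths from this PR's file tree; the data, batch size, and neighbourhood radius are made up for illustration:

import numpy as np
from n2v.internals.N2V_DataWrapper import N2V_DataWrapper
from n2v.utils import n2v_utils
from n2v.utils.n2v_utils import pm_uniform_withCP

X = np.random.rand(16, 64, 64, 1).astype(np.float32)          # toy training patches
perc_pix = 0.198                                               # config.n2v_perc_pix
patch_shape = (64, 64)                                         # config.n2v_patch_shape
manipulator = pm_uniform_withCP(5)

# Target = input concatenated with an (initially empty) mask channel.
Y = np.concatenate((X, np.zeros_like(X)), axis=-1)
training_data = N2V_DataWrapper(X, Y, batch_size=4, perc_pix=perc_pix,
                                shape=patch_shape, value_manipulation=manipulator)

# Validation masking is precomputed once with the same percentage.
val_X = np.random.rand(4, 64, 64, 1).astype(np.float32)
val_Y = np.concatenate((val_X, np.zeros_like(val_X)), axis=-1)
n2v_utils.manipulate_val_data(val_X, val_Y, perc_pix=perc_pix,
                              shape=patch_shape, value_manipulation=manipulator)
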
77 changes: 43 additions & 34 deletions n2v/utils/n2v_utils.py
@@ -7,9 +7,6 @@ def get_subpatch(patch, coord, local_sub_patch_radius):
start = np.maximum(0, np.array(coord) - local_sub_patch_radius)
end = start + local_sub_patch_radius*2 + 1

start = np.append(start, 0)
end = np.append(end, patch.shape[-1])

shift = np.minimum(0, patch.shape - end)

start += shift
@@ -37,66 +34,78 @@ def normal_int(mean, sigma, w):


def pm_normal_withoutCP(local_sub_patch_radius):
def normal_withoutCP(patch, coord, dims):
rand_coords = random_neighbor(patch.shape, coord)
return patch[tuple(rand_coords)]
def normal_withoutCP(patch, coords, dims):
vals = []
for coord in zip(*coords):
rand_coords = random_neighbor(patch.shape, coord)
vals.append(patch[tuple(rand_coords)])
return vals
return normal_withoutCP


def pm_uniform_withCP(local_sub_patch_radius):
def random_neighbor_withCP_uniform(patch, coord, dims):
sub_patch = get_subpatch(patch, coord,local_sub_patch_radius)
rand_coords = [np.random.randint(0, s) for s in sub_patch.shape[0:dims]]
return sub_patch[tuple(rand_coords)]
def random_neighbor_withCP_uniform(patch, coords, dims):
vals = []
for coord in zip(*coords):
sub_patch = get_subpatch(patch, coord,local_sub_patch_radius)
rand_coords = [np.random.randint(0, s) for s in sub_patch.shape[0:dims]]
vals.append(sub_patch[tuple(rand_coords)])
return vals
return random_neighbor_withCP_uniform


def pm_normal_additive(pixel_gauss_sigma):
def pixel_gauss(patch, coord, dims):
return np.random.normal(patch[tuple(coord)], pixel_gauss_sigma)
def pixel_gauss(patch, coords, dims):
vals = []
for coord in zip(*coords):
vals.append(np.random.normal(patch[tuple(coord)], pixel_gauss_sigma))
return vals
return pixel_gauss


def pm_normal_fitted(local_sub_patch_radius):
def local_gaussian(patch, coord, dims):
sub_patch = get_subpatch(patch, coord, local_sub_patch_radius)
axis = tuple(range(dims))
return np.random.normal(np.mean(sub_patch, axis=axis), np.std(sub_patch, axis=axis))
def local_gaussian(patch, coords, dims):
vals = []
for coord in zip(*coords):
sub_patch = get_subpatch(patch, coord, local_sub_patch_radius)
axis = tuple(range(dims))
vals.append(np.random.normal(np.mean(sub_patch, axis=axis), np.std(sub_patch, axis=axis)))
return vals
return local_gaussian


def pm_identity(local_sub_patch_radius):
def identity(patch, coord, dims):
return patch[tuple(coord)]
def identity(patch, coords, dims):
vals = []
for coord in zip(*coords):
vals.append(patch[coord])
return vals
return identity


def manipulate_val_data(X_val, Y_val, num_pix=64, shape=(64, 64), value_manipulation=pm_uniform_withCP(5)):
def manipulate_val_data(X_val, Y_val, perc_pix=0.198, shape=(64, 64), value_manipulation=pm_uniform_withCP(5)):
dims = len(shape)
if dims == 2:
box_size = np.round(np.sqrt(shape[0] * shape[1] / num_pix)).astype(np.int)
box_size = np.round(np.sqrt(100/perc_pix)).astype(np.int)
get_stratified_coords = dw.__get_stratified_coords2D__
rand_float = dw.__rand_float_coords2D__(box_size)
elif dims == 3:
box_size = np.round(np.power(shape[0] * shape[1] * shape[2] / num_pix, 1 / 3.0)).astype(np.int)
box_size = np.round(np.sqrt(100/perc_pix)).astype(np.int)
get_stratified_coords = dw.__get_stratified_coords3D__
rand_float = dw.__rand_float_coords3D__(box_size)

n_chan = X_val.shape[-1]

Y_val *= 0
for j in tqdm(range(X_val.shape[0]), desc='Preparing validation data: '):
coords = get_stratified_coords(rand_float, box_size=box_size,
shape=np.array(X_val.shape)[1:-1])
y_val = []
x_val = []
for k in range(len(coords)):
y_val.append(np.copy(Y_val[(j, *coords[k], ...)]))
x_val.append(value_manipulation(X_val[j, ...], coords[k], dims))

Y_val[j] *= 0

for k in range(len(coords)):
for c in range(n_chan):
Y_val[(j, *coords[k], c)] = y_val[k][c]
Y_val[(j, *coords[k], n_chan+c)] = 1
X_val[(j, *coords[k], c)] = x_val[k][c]
for c in range(n_chan):
indexing = (j,) + coords + (c,)
indexing_mask = (j,) + coords + (c + n_chan,)
y_val = X_val[indexing]
x_val = value_manipulation(X_val[j, ..., c], coords, dims)

Y_val[indexing] = y_val
Y_val[indexing_mask] = 1
X_val[indexing] = x_val
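
The pixel manipulators above now operate on the new coordinate format: each one receives the whole tuple of per-axis coordinate arrays and returns one replacement value per blind-spot, instead of being called once per pixel. A small usage sketch (patch contents and coordinates are invented; the 11x11 window follows from the radius of 5):

import numpy as np
from n2v.utils.n2v_utils import pm_uniform_withCP

patch = np.random.rand(64, 64).astype(np.float32)    # one channel of a 2D training patch
coords = ([3, 17, 40], [5, 22, 60])                  # (y_coords, x_coords) from the stratified sampler

manipulator = pm_uniform_withCP(5)                   # uniform draw from an 11x11 window around each pixel
new_vals = manipulator(patch, coords, dims=2)        # one replacement value per blind-spot
patch[coords] = new_vals                             # overwrite the blind-spots in place
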
2 changes: 1 addition & 1 deletion n2v/version.py
@@ -1 +1 @@
__version__ = '0.1.9'
__version__ = '0.1.10'
24 changes: 12 additions & 12 deletions tests/test_Noise2VoidDataWrapper.py
@@ -15,7 +15,7 @@ def create_data(in_shape, out_shape):
def _sample2D(in_shape, out_shape, seed):
X, Y, X_Batches, Y_Batches, indices = create_data(in_shape, out_shape)
np.random.seed(seed)
N2V_DataWrapper.__subpatch_sampling2D__(X, Y, X_Batches, Y_Batches, indices,
N2V_DataWrapper.__subpatch_sampling2D__(X, X_Batches, indices,
range=in_shape[1:3]-out_shape[1:3], shape=out_shape[1:3])

assert ([*X_Batches.shape] == out_shape).all()
@@ -24,20 +24,18 @@ def _sample2D(in_shape, out_shape, seed):
range_x = in_shape[2] - out_shape[2]
for j in indices:
assert np.sum(X_Batches[j]) != 0
assert np.sum(Y_Batches[j]) != 0
y_start = np.random.randint(0, range_y + 1)
x_start = np.random.randint(0, range_x + 1)
assert np.sum(X_Batches[j] - X[j, y_start:y_start+out_shape[1], x_start:x_start+out_shape[2]]) == 0

for j in range(in_shape[0]):
if j not in indices:
assert np.sum(X_Batches[j]) == 0
assert np.sum(Y_Batches[j]) == 0

def _sample3D(in_shape, out_shape, seed):
X, Y, X_Batches, Y_Batches, indices = create_data(in_shape, out_shape)
np.random.seed(seed)
N2V_DataWrapper.__subpatch_sampling3D__(X, Y, X_Batches, Y_Batches, indices,
N2V_DataWrapper.__subpatch_sampling3D__(X, X_Batches, indices,
range=in_shape[1:4]-out_shape[1:4], shape=out_shape[1:4])

assert ([*X_Batches.shape] == out_shape).all()
@@ -47,7 +45,6 @@ def _sample3D(in_shape, out_shape, seed):
range_x = in_shape[3] - out_shape[3]
for j in indices:
assert np.sum(X_Batches[j]) != 0
assert np.sum(Y_Batches[j]) != 0
z_start = np.random.randint(0, range_z + 1)
y_start = np.random.randint(0, range_y + 1)
x_start = np.random.randint(0, range_x + 1)
@@ -56,7 +53,6 @@ def _sample3D(in_shape, out_shape, seed):
for j in range(in_shape[0]):
if j not in indices:
assert np.sum(X_Batches[j]) == 0
assert np.sum(Y_Batches[j]) == 0

_sample2D(np.array([20, 64, 64, 2]), np.array([20, 32, 32, 2]), 1)
_sample2D(np.array([10, 25, 25, 1]), np.array([10, 12, 12, 1]), 2)
@@ -115,28 +111,32 @@ def _getitem2D(y_shape):
else:
X = Y[:,:,:,:n_chan]
val_manipulator = random_neighbor_withCP_uniform
dw = N2V_DataWrapper(X, Y, 4, num_pix=16, shape=(32, 32), value_manipulation=val_manipulator)
dw = N2V_DataWrapper(X, Y, 4, perc_pix=0.198, shape=(32, 32), value_manipulation=val_manipulator)

x_batch, y_batch = dw.__getitem__(0)
assert x_batch.shape == (4, 32, 32, int(n_chan))
assert y_batch.shape == (4, 32, 32, int(2*n_chan))
assert np.sum(y_batch[:,:,:,n_chan:]) == 16*4*n_chan
# At least one pixel has to be a blind-spot per batch sample
assert np.sum(y_batch[..., n_chan:]) >= 4 * n_chan
# At most four pixels can be affected per batch sample
assert np.sum(y_batch[..., n_chan:]) <= 4*4 * n_chan

assert np.sum(X[:,:32,:32,:n_chan]*y_batch[:,:,:,n_chan:] - Y[:,:32,:32,:n_chan]*y_batch[:,:,:,n_chan:]) <= 10e-12

def _getitem3D(y_shape):
Y = create_data(y_shape)
n_chan = y_shape[-1]//2
X = Y[:,:,:,:,0][:,:,:,:,np.newaxis]
val_manipulator = random_neighbor_withCP_uniform
dw = N2V_DataWrapper(X, Y, 4, num_pix=64, shape=(32, 32, 32), value_manipulation=val_manipulator)
dw = N2V_DataWrapper(X, Y, 4, perc_pix=0.198, shape=(32, 32, 32), value_manipulation=val_manipulator)

x_batch, y_batch = dw.__getitem__(0)
assert x_batch.shape == (4, 32, 32, 32, 1)
assert y_batch.shape == (4, 32, 32, 32, 2)
assert np.sum(y_batch[:,:,:,:,1]) == 64*4
# At least one pixel has to be a blind-spot per batch sample
assert np.sum(y_batch[..., n_chan:]) >= 1*4 * n_chan
# At most 8 pixels can be affected per batch sample
assert np.sum(y_batch[..., n_chan:]) <= 8*4 * n_chan

assert np.sum(X[:,:32,:32,:32,:n_chan]*y_batch[:,:,:,:,n_chan:] - Y[:,:32,:32,:32,:n_chan]*y_batch[:,:,:,:,n_chan:]) <= 10e-12

_getitem2D(np.array([4, 32, 32, 2]))
_getitem2D(np.array([4, 64, 64, 2]))
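
Because the blind-spot count is now derived from perc_pix and the box grid, the updated tests assert a range instead of an exact count. The bounds follow from simple arithmetic (this check is not part of the test file):

import numpy as np

perc_pix = 0.198
box_size = int(np.round(np.sqrt(100 / perc_pix)))   # -> 22

boxes_per_axis = int(np.ceil(32 / box_size))        # -> 2 for a 32-pixel patch edge
max_2d = boxes_per_axis ** 2                        # -> 4 blind-spots at most per 2D sample
max_3d = boxes_per_axis ** 3                        # -> 8 blind-spots at most per 3D sample

# The box anchored at the origin always lands inside the patch, so every sample
# gets at least one blind-spot; with batch_size=4 the mask sums per channel fall
# in [4, 16] for the 2D case and [4, 32] for the 3D case, matching the asserts.
print(max_2d, max_3d)                               # 4 8
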