From 86e6a7882fdd994e2b9e1c696b28a8648ff4a89c Mon Sep 17 00:00:00 2001 From: KOLANICH Date: Sun, 6 Mar 2022 18:38:17 +0300 Subject: [PATCH] Allowed to provide discrete probability distribution into `OneOf` --- imgaug/augmenters/meta.py | 111 +++++++++++++++++++++++++++++++++++--- 1 file changed, 104 insertions(+), 7 deletions(-) diff --git a/imgaug/augmenters/meta.py b/imgaug/augmenters/meta.py index 9279a0a4d..9c30c0bd6 100644 --- a/imgaug/augmenters/meta.py +++ b/imgaug/augmenters/meta.py @@ -3467,7 +3467,7 @@ def __str__(self): str(self.random_order), augs_str, self.deterministic) -class OneOf(SomeOf): +class OneOf(Augmenter): """Augmenter that always executes exactly one of its children. **Supported dtypes**: @@ -3476,6 +3476,12 @@ class OneOf(SomeOf): Parameters ---------- + + p : float or imgaug.parameters.StochasticParameter, optional + Sets the probability with which the given augmenters will be applied to + input images/data. E.g. a value of ``0.5`` will result in ``50%`` of + all input images (or other augmentables) being augmented. + children : imgaug.augmenters.meta.Augmenter or list of imgaug.augmenters.meta.Augmenter The choices of augmenters to apply. @@ -3525,16 +3531,107 @@ class OneOf(SomeOf): """ - def __init__(self, children, + @property + def n(self): + """For compatibility to SomeOf""" + return 1 + + def __init__(self, children, p=None, seed=None, name=None, - random_state="deprecated", deterministic="deprecated"): - super(OneOf, self).__init__( - n=1, - children=children, - random_order=False, + random_state="deprecated", deterministic="deprecated",): + + Augmenter.__init__( + self, seed=seed, name=name, random_state=random_state, deterministic=deterministic) + if children is None: + children = [] + elif isinstance(children, Augmenter): + # this must be separate from `list.__init__(self, children)`, + # otherwise in `SomeOf(OneOf(...))` the OneOf(...) is + # interpreted as a list and OneOf's children become SomeOf's + # children + children = [children] + elif ia.is_iterable(children): + assert all([isinstance(child, Augmenter) for child in children]), ( + "Expected all children to be augmenters, got types %s." % ( + ", ".join([str(type(v)) for v in children]))) + else: + raise Exception("Expected None or Augmenter or list of Augmenter, " + "got %s." % (type(children),)) + self.children = children + + if p is None: + p = np.ones(len(children)) / len(children) + #p = iap.DiscreteUniform(len(children)) + + if isinstance(p, (list, tuple, np.ndarray)): + assert len(p) == len(children) + p = iap.Choice(range(len(children)), p=p) + + self.p = p + + # Added in 0.4.0. + def _augment_batch_(self, batch, random_state, parents, hooks): + with batch.propagation_hooks_ctx(self, hooks, parents): + samples = self.p.draw_samples((batch.nb_rows,), random_state=random_state) + + # For then_list: collect augmentables to be processed by then_list + # augmenters, apply them to the list, then map back to the output + # list. Analogous for else_list. + # TODO maybe this would be easier if augment_*() accepted a list + # that can contain Nones + for i, augmenter_id in enumerate(samples): + augmenter = self.children[augmenter_id] + if augmenter is not None: + batch_sub = batch.subselect_rows_by_indices([i]) + batch_sub = augmenter.augment_batch_( + batch_sub, + parents=parents + [self], + hooks=hooks + ) + batch = batch.invert_subselect_rows_by_indices_([i], + batch_sub) + + return batch + + def _to_deterministic(self): + aug = self.copy() + aug.children = [el.to_deterministic() for el in aug.children] + aug.deterministic = True + aug.random_state = self.random_state.derive_rng_() + return aug + + def get_parameters(self): + """See :func:`~imgaug.augmenters.meta.Augmenter.get_parameters`.""" + return [self.p] + + def get_children_lists(self): + """See :func:`~imgaug.augmenters.meta.Augmenter.get_children_lists`.""" + result = [] + if self.children is not None: + result.append(self.children) + return result + + def __str__(self): + pattern = ( + "%s(" + "p=%s, name=%s, children=%s, deterministic=%s" + ")") + return pattern % ( + self.__class__.__name__, self.p, self.name, self.children, + self.deterministic) + + def __iter__(self): + return iter(self.children) + + def __getitem__(self, k): + return self.children[k] + + +ABCMeta.register(SomeOf, OneOf) + class Sometimes(Augmenter): """Apply child augmenter(s) with a probability of `p`.