Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allowed to provide discrete probability distribution into OneOf #813

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 104 additions & 7 deletions imgaug/augmenters/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -3467,7 +3467,7 @@ def __str__(self):
str(self.random_order), augs_str, self.deterministic)


class OneOf(SomeOf):
class OneOf(Augmenter):
"""Augmenter that always executes exactly one of its children.

**Supported dtypes**:
Expand All @@ -3476,6 +3476,12 @@ class OneOf(SomeOf):

Parameters
----------

p : float or imgaug.parameters.StochasticParameter, optional
Sets the probability with which the given augmenters will be applied to
input images/data. E.g. a value of ``0.5`` will result in ``50%`` of
all input images (or other augmentables) being augmented.

children : imgaug.augmenters.meta.Augmenter or list of imgaug.augmenters.meta.Augmenter
The choices of augmenters to apply.

Expand Down Expand Up @@ -3525,16 +3531,107 @@ class OneOf(SomeOf):

"""

def __init__(self, children,
@property
def n(self):
"""For compatibility to SomeOf"""
return 1

def __init__(self, children, p=None,
seed=None, name=None,
random_state="deprecated", deterministic="deprecated"):
super(OneOf, self).__init__(
n=1,
children=children,
random_order=False,
random_state="deprecated", deterministic="deprecated",):

Augmenter.__init__(
self,
seed=seed, name=name,
random_state=random_state, deterministic=deterministic)

if children is None:
children = []
elif isinstance(children, Augmenter):
# this must be separate from `list.__init__(self, children)`,
# otherwise in `SomeOf(OneOf(...))` the OneOf(...) is
# interpreted as a list and OneOf's children become SomeOf's
# children
children = [children]
elif ia.is_iterable(children):
assert all([isinstance(child, Augmenter) for child in children]), (
"Expected all children to be augmenters, got types %s." % (
", ".join([str(type(v)) for v in children])))
else:
raise Exception("Expected None or Augmenter or list of Augmenter, "
"got %s." % (type(children),))
self.children = children

if p is None:
p = np.ones(len(children)) / len(children)
#p = iap.DiscreteUniform(len(children))

if isinstance(p, (list, tuple, np.ndarray)):
assert len(p) == len(children)
p = iap.Choice(range(len(children)), p=p)

self.p = p

# Added in 0.4.0.
def _augment_batch_(self, batch, random_state, parents, hooks):
with batch.propagation_hooks_ctx(self, hooks, parents):
samples = self.p.draw_samples((batch.nb_rows,), random_state=random_state)

# For then_list: collect augmentables to be processed by then_list
# augmenters, apply them to the list, then map back to the output
# list. Analogous for else_list.
# TODO maybe this would be easier if augment_*() accepted a list
# that can contain Nones
for i, augmenter_id in enumerate(samples):
augmenter = self.children[augmenter_id]
if augmenter is not None:
batch_sub = batch.subselect_rows_by_indices([i])
batch_sub = augmenter.augment_batch_(
batch_sub,
parents=parents + [self],
hooks=hooks
)
batch = batch.invert_subselect_rows_by_indices_([i],
batch_sub)

return batch

def _to_deterministic(self):
aug = self.copy()
aug.children = [el.to_deterministic() for el in aug.children]
aug.deterministic = True
aug.random_state = self.random_state.derive_rng_()
return aug

def get_parameters(self):
"""See :func:`~imgaug.augmenters.meta.Augmenter.get_parameters`."""
return [self.p]

def get_children_lists(self):
"""See :func:`~imgaug.augmenters.meta.Augmenter.get_children_lists`."""
result = []
if self.children is not None:
result.append(self.children)
return result

def __str__(self):
pattern = (
"%s("
"p=%s, name=%s, children=%s, deterministic=%s"
")")
return pattern % (
self.__class__.__name__, self.p, self.name, self.children,
self.deterministic)

def __iter__(self):
return iter(self.children)

def __getitem__(self, k):
return self.children[k]


ABCMeta.register(SomeOf, OneOf)


class Sometimes(Augmenter):
"""Apply child augmenter(s) with a probability of `p`.
Expand Down