From 1b84bfe0c9faacedcc0caf257488a029c4eb8c14 Mon Sep 17 00:00:00 2001 From: keisen Date: Tue, 27 Jul 2021 01:17:03 +0900 Subject: [PATCH 1/2] Add support for 'interpolation' option in input modifier. --- .../input_modifiers.py | 25 ++++++++++++++----- tf_keras_vis/utils/__init__.py | 15 +++++++++++ 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/tf_keras_vis/activation_maximization/input_modifiers.py b/tf_keras_vis/activation_maximization/input_modifiers.py index 36b4576..19811dc 100644 --- a/tf_keras_vis/activation_maximization/input_modifiers.py +++ b/tf_keras_vis/activation_maximization/input_modifiers.py @@ -5,6 +5,8 @@ import tensorflow as tf from scipy.ndimage.interpolation import rotate, zoom +from ..utils import order + class InputModifier(ABC): """Abstract class for defining an input modifier. @@ -51,12 +53,15 @@ def __call__(self, seed_input) -> np.ndarray: class Rotate(InputModifier): """An input modifier that introduces random rotation. """ - def __init__(self, axes=(1, 2), degree=3.0) -> None: + def __init__(self, axes=(1, 2), degree=3.0, interpolation='bilinear') -> None: """ Args: axes: The two axes that define the plane of rotation. Defaults to (1, 2). degree: The amount of rotation to apply. Defaults to 3.0. + interpolation: An integer or string. When integer, `interpolation`'s specification is + the same as `order` option of scipy-ndimage API. When string, `interpolation` MUST + be one of `"nearest"`, `"bilinear"` and `"cubic"`. Defaults to `"bilinear"`. Raises: ValueError: When axes is not a tuple of two ints. @@ -68,6 +73,7 @@ def __init__(self, axes=(1, 2), degree=3.0) -> None: self.axes = axes self.degree = float(degree) self.random_generator = np.random.default_rng() + self.order = order(interpolation) def __call__(self, seed_input) -> np.ndarray: ndim = len(seed_input.shape) @@ -80,7 +86,7 @@ def __call__(self, seed_input) -> np.ndarray: self.random_generator.uniform(-self.degree, self.degree), axes=self.axes, reshape=False, - order=1, + order=self.order, mode='reflect', prefilter=False) return seed_input @@ -89,26 +95,33 @@ def __call__(self, seed_input) -> np.ndarray: class Rotate2D(Rotate): """An input modifier for 2D that introduces random rotation. """ - def __init__(self, degree=3.0) -> None: + def __init__(self, degree=3.0, interpolation='bilinear') -> None: """ Args: degree: The amount of rotation to apply. Defaults to 3.0. + interpolation: An integer or string. When integer, `interpolation`'s specification is + the same as `order` option of scipy-ndimage API. When string, `interpolation` MUST + be one of `"nearest"`, `"bilinear"` and `"cubic"`. Defaults to `"bilinear"`. """ - super().__init__(axes=(1, 2), degree=degree) + super().__init__(axes=(1, 2), degree=degree, interpolation=interpolation) class Scale(InputModifier): """An input modifier that introduces randam scaling. """ - def __init__(self, low=0.9, high=1.1) -> None: + def __init__(self, low=0.9, high=1.1, interpolation='bilinear') -> None: """ Args: low (float, optional): Lower boundary of the zoom factor. Defaults to 0.9. high (float, optional): Higher boundary of the zoom factor. Defaults to 1.1. + interpolation: An integer or string. When integer, `interpolation`'s specification is + the same as `order` option of scipy-ndimage API. When string, `interpolation` MUST + be one of `"nearest"`, `"bilinear"` and `"cubic"`. Defaults to `"bilinear"`. """ self.low = low self.high = high self.random_generator = np.random.default_rng() + self.order = order(interpolation) def __call__(self, seed_input) -> np.ndarray: ndim = len(seed_input.shape) @@ -121,7 +134,7 @@ def __call__(self, seed_input) -> np.ndarray: _factor = factor = self.random_generator.uniform(self.low, self.high) factor *= np.ones(ndim - 2) factor = (1, ) + tuple(factor) + (1, ) - seed_input = zoom(seed_input, factor, order=1, mode='reflect', prefilter=False) + seed_input = zoom(seed_input, factor, order=self.order, mode='reflect', prefilter=False) if _factor > 1.0: indices = (self._central_crop_range(x, e) for x, e in zip(seed_input.shape, shape)) indices = (slice(start, stop) for start, stop in indices) diff --git a/tf_keras_vis/utils/__init__.py b/tf_keras_vis/utils/__init__.py index f24c5ef..df2ddca 100644 --- a/tf_keras_vis/utils/__init__.py +++ b/tf_keras_vis/utils/__init__.py @@ -132,3 +132,18 @@ def lower_precision_dtype(model): (isinstance(layer, tf.keras.Model) and is_mixed_precision(layer)): return layer.compute_dtype return model.dtype # pragma: no cover + + +def order(value): + if isinstance(value, int): + return value + if isinstance(value, str): + value = value.lower() + if value == 'nearest': + return 0 + if value == 'bilinear': + return 1 + if value == 'cubic': + return 3 + raise ValueError(f"{value} is not supported. " + "The value MUST be an integer or one of 'nearest', 'bilinear' or 'cubic'.") From f382fec54811757b73b635973fbe69c47df7194a Mon Sep 17 00:00:00 2001 From: keisen Date: Tue, 27 Jul 2021 01:26:19 +0900 Subject: [PATCH 2/2] Add 'convert_iterable_to_list' to listify() --- tf_keras_vis/utils/__init__.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/tf_keras_vis/utils/__init__.py b/tf_keras_vis/utils/__init__.py index df2ddca..d5e7570 100644 --- a/tf_keras_vis/utils/__init__.py +++ b/tf_keras_vis/utils/__init__.py @@ -1,4 +1,5 @@ import os +from collections.abc import Iterable from typing import Tuple import numpy as np @@ -39,27 +40,38 @@ def num_of_gpus() -> Tuple[int, int]: return 0, 0 -def listify(value, return_empty_list_if_none=True, convert_tuple_to_list=True) -> list: - """Ensures that the value is a list. +def listify(value, + return_empty_list_if_none=True, + convert_tuple_to_list=True, + convert_iterable_to_list=False) -> list: + """Ensures that `value` is a list. - If it is not a list, it creates a new list with `value` as an item. + If `value` is not a list, this function creates an new list that includes `value`. Args: value (object): A list or something else. - return_empty_list_if_none (bool, optional): When True (default), None you passed as `value` - will be converted to a empty list (i.e., `[]`). When False, None will be converted to - a list that has an None (i.e., `[None]`). Defaults to True. - convert_tuple_to_list (bool, optional): When True (default), a tuple you passed as `value` - will be converted to a list. When False, a tuple will be unconverted - (i.e., returning a tuple object that was passed as `value`). Defaults to True. + return_empty_list_if_none (bool, optional): When True (default), `None` you passed as + `value` will be converted to a empty list (i.e., `[]`). When False, `None` will be + converted to a list that contains an `None` (i.e., `[None]`). Defaults to True. + convert_tuple_to_list (bool, optional):When True (default), a tuple object you + passed as `value` will be converted to a list. When False, a tuple object will be + unconverted (i.e., returning a list of a tuple object). Defaults to True. + convert_iterable_to_list (bool, optional): When True (default), an iterable object you + passed as `value` will be converted to a list. When False, an iterable object will be + unconverted (i.e., returning a list of an iterable object). Defaults to False. Returns: - list: A list. When `value` is a tuple and `convert_tuple_to_list` is False, a tuple. + list: A list """ if not isinstance(value, list): if value is None and return_empty_list_if_none: value = [] elif isinstance(value, tuple) and convert_tuple_to_list: value = list(value) + elif isinstance(value, Iterable) and convert_iterable_to_list: + if not convert_tuple_to_list: + raise ValueError("When 'convert_tuple_to_list' option is False," + "'convert_iterable_to_list' option should also be False.") + value = list(value) else: value = [value] return value