Skip to content

Commit

Permalink
Pipelines for FBA (open-mmlab#209)
Browse files Browse the repository at this point in the history
* Added pipelines for FBA.

* Modified.

* Added test.

* Tiny.

* Added __repr__  funcs.

* Added use_cache.

* Added test.

* Tiny.

* Added unittest.

* Speed up via broadcasting.

* Tiny.

* Use online cache.

* Tiny.

* Tiny.

* Simplified.

* Tiny。

* Tiny.

* Tiny.

* Tiny.

* Polished the doc string.
  • Loading branch information
yaochaorui authored Feb 27, 2021
1 parent ebb9b37 commit 4f90c0c
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 10 deletions.
5 changes: 3 additions & 2 deletions mmedit/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
RandomLoadResizeBg)
from .matting_aug import (CompositeFg, GenerateSeg, GenerateSoftSeg,
GenerateTrimap, GenerateTrimapWithDistTransform,
MergeFgAndBg, PerturbBg)
MergeFgAndBg, PerturbBg, TransformTrimap)
from .normalization import Normalize, RescaleToZeroOne

__all__ = [
Expand All @@ -25,5 +25,6 @@
'MergeFgAndBg', 'CompositeFg', 'TemporalReverse', 'LoadImageFromFileList',
'GenerateFrameIndices', 'GenerateFrameIndiceswithPadding', 'FixedCrop',
'LoadPairedImageFromFile', 'GenerateSoftSeg', 'GenerateSeg', 'PerturbBg',
'CropAroundFg', 'GetSpatialDiscountMask', 'GenerateTrimapWithDistTransform'
'CropAroundFg', 'GetSpatialDiscountMask',
'GenerateTrimapWithDistTransform', 'TransformTrimap'
]
30 changes: 24 additions & 6 deletions mmedit/datasets/pipelines/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class LoadImageFromFile:
Default: 'bgr'.
save_original_img (bool): If True, maintain a copy of the image in
`results` dict with name of `f'ori_{key}'`. Default: False.
use_cache (bool): If True, keep each decoded image in an in-memory
cache keyed by its file path and reuse it on later calls.
Default: False.
kwargs (dict): Args for file client.
"""

Expand All @@ -30,6 +31,7 @@ def __init__(self,
flag='color',
channel_order='bgr',
save_original_img=False,
use_cache=False,
**kwargs):
self.io_backend = io_backend
self.key = key
Expand All @@ -38,6 +40,8 @@ def __init__(self,
self.channel_order = channel_order
self.kwargs = kwargs
self.file_client = None
self.use_cache = use_cache
self.cache = None

def __call__(self, results):
"""Call function.
Expand All @@ -49,13 +53,26 @@ def __call__(self, results):
Returns:
dict: A dict containing the processed data and information.
"""
filepath = str(results[f'{self.key}_path'])
if self.file_client is None:
self.file_client = FileClient(self.io_backend, **self.kwargs)
filepath = str(results[f'{self.key}_path'])
img_bytes = self.file_client.get(filepath)
img = mmcv.imfrombytes(
img_bytes, flag=self.flag, channel_order=self.channel_order) # HWC

if self.use_cache:
if self.cache is None:
self.cache = dict()
if filepath in self.cache:
img = self.cache[filepath]
else:
img_bytes = self.file_client.get(filepath)
img = mmcv.imfrombytes(
img_bytes,
flag=self.flag,
channel_order=self.channel_order) # HWC
self.cache[filepath] = img
else:
img_bytes = self.file_client.get(filepath)
img = mmcv.imfrombytes(
img_bytes, flag=self.flag,
channel_order=self.channel_order) # HWC
results[self.key] = img
results[f'{self.key}_path'] = filepath
results[f'{self.key}_ori_shape'] = img.shape
Expand All @@ -68,7 +85,8 @@ def __repr__(self):
repr_str = self.__class__.__name__
repr_str += (
f'(io_backend={self.io_backend}, key={self.key}, '
f'flag={self.flag}, save_original_img={self.save_original_img})')
f'flag={self.flag}, save_original_img={self.save_original_img}, '
f'channel_order={self.channel_order}, use_cache={self.use_cache})')
return repr_str


Expand Down
56 changes: 56 additions & 0 deletions mmedit/datasets/pipelines/matting_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def __call__(self, results):
results['merged'] = merged
return results

def __repr__(self) -> str:
repr_str = self.__class__.__name__
return repr_str


@PIPELINES.register_module()
class GenerateTrimap:
Expand Down Expand Up @@ -568,3 +572,55 @@ def __repr__(self):
f'dilate_iter_range={self.dilate_iter_range}, '
f'blur_ksizes={self.blur_ksizes})')
return repr_str


@PIPELINES.register_module()
class TransformTrimap:
    """Generate two-channel trimap and encode it into six-channel.

    This class will generate a two-channel trimap composed of definite
    background (trimap == 0) and definite foreground (trimap == 255) masks
    and encode it into a six-channel trimap: for each mask, the squared L2
    distance ``d**2`` to the mask is turned into three Gaussian responses
    ``exp(-d**2 / (2 * (s * 320)**2))`` with ``s`` in (0.02, 0.08, 0.16).
    Channels 0-2 come from the background mask and channels 3-5 from the
    foreground mask. The three channels of a mask that is empty stay zero.

    Required key is "trimap", added key is "transformed_trimap".

    Adopted from the following repository:
    https://github.com/MarcoForte/FBA_Matting/blob/master/networks/transforms.py.
    """

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.

        Raises:
            KeyError: If ``results`` has no "trimap" entry.
            AssertionError: If the trimap is not two-dimensional.
        """
        trimap = results['trimap']
        # Kept as an assert: callers rely on AssertionError for bad shapes.
        assert len(trimap.shape) == 2
        h, w = trimap.shape[:2]

        # generate two-channel trimap
        trimap2 = np.zeros((h, w, 2), dtype=np.uint8)
        trimap2[trimap == 0, 0] = 255
        trimap2[trimap == 255, 1] = 255

        trimap_trans = np.zeros((h, w, 6), dtype=np.float32)
        factor = np.array([[[0.02, 0.08, 0.16]]], dtype=np.float32)
        L = 320
        # Loop-invariant Gaussian denominator, shape (1, 1, 3) so that it
        # broadcasts against the (h, w, 1) distance map below.
        denominator = 2 * ((factor * L)**2)

        for k in range(2):
            mask = trimap2[:, :, k]
            if not np.any(mask):
                # Empty mask: leave the three channels at zero.
                continue
            # distanceTransform measures the distance to the nearest zero
            # pixel, so invert the mask to get the distance to the mask.
            dt_mask = -cv2.distanceTransform(255 - mask, cv2.DIST_L2, 0)**2
            trimap_trans[..., 3 * k:3 * k + 3] = np.exp(
                dt_mask[..., None] / denominator)

        results['transformed_trimap'] = trimap_trans
        return results

    def __repr__(self):
        """Return the class name as the string representation."""
        return self.__class__.__name__
20 changes: 19 additions & 1 deletion tests/test_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def test_load_image_from_file():
assert repr(image_loader) == (
image_loader.__class__.__name__ +
('(io_backend=disk, key=lq, '
'flag=color, save_original_img=False)'))
'flag=color, save_original_img=False, channel_order=bgr, '
'use_cache=False)'))

results = dict(lq_path=path_baboon_x4)
config = dict(
Expand All @@ -66,6 +67,23 @@ def test_load_image_from_file():
np.testing.assert_almost_equal(results['ori_lq'], results['lq'])
assert id(results['ori_lq']) != id(results['lq'])

# test: use_cache
results = dict(gt_path=path_baboon)
config = dict(io_backend='disk', key='gt', use_cache=True)
image_loader = LoadImageFromFile(**config)
assert image_loader.cache is None
assert repr(image_loader) == (
image_loader.__class__.__name__ +
('(io_backend=disk, key=gt, '
'flag=color, save_original_img=False, channel_order=bgr, '
'use_cache=True)'))
results = image_loader(results)
assert image_loader.cache is not None
assert str(path_baboon) in image_loader.cache
assert results['gt'].shape == (480, 500, 3)
assert results['gt_path'] == str(path_baboon)
np.testing.assert_almost_equal(results['gt'], img_baboon)


def test_load_image_from_file_list():
path_baboon = Path(__file__).parent / 'data' / 'gt' / 'baboon.png'
Expand Down
30 changes: 29 additions & 1 deletion tests/test_trimap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from mmedit.datasets.pipelines import (CompositeFg, GenerateSeg,
GenerateSoftSeg, GenerateTrimap,
GenerateTrimapWithDistTransform,
MergeFgAndBg, PerturbBg)
MergeFgAndBg, PerturbBg,
TransformTrimap)


def check_keys_contain(result_keys, target_keys):
Expand Down Expand Up @@ -323,3 +324,30 @@ def test_generate_soft_seg():
'erode_iter_range=(1, 2), dilate_iter_range=(1, 2), '
'blur_ksizes=[(11, 11)])')
assert repr(generate_soft_seg) == repr_str


def test_transform_trimap():
    """Check TransformTrimap input validation, output layout and repr."""
    transform = TransformTrimap()
    target_keys = ['trimap', 'transformed_trimap']

    # the "trimap" key is required
    with pytest.raises(KeyError):
        transform(dict())

    # only two-dimensional trimaps are accepted; build the fixture outside
    # the ``raises`` body so that only the call under test can satisfy it
    dummy_trimap = np.zeros((100, 100, 1), dtype=np.uint8)
    with pytest.raises(AssertionError):
        transform(dict(trimap=dummy_trimap))

    # generate dummy trimap with shape (100,100)
    dummy_trimap = np.zeros((100, 100), dtype=np.uint8)
    dummy_trimap[:50, :50] = 255
    results = dict(trimap=dummy_trimap)
    results_transformed = transform(results)
    assert check_keys_contain(results_transformed.keys(), target_keys)
    assert results_transformed['trimap'].shape == dummy_trimap.shape

    # the encoded trimap keeps the spatial size and has six float32 channels
    transformed = results_transformed['transformed_trimap']
    assert transformed.shape == dummy_trimap.shape + (6, )
    assert transformed.dtype == np.float32

    assert repr(transform) == transform.__class__.__name__

0 comments on commit 4f90c0c

Please sign in to comment.