diff --git a/docs/tools_scripts.md b/docs/tools_scripts.md
index 1b224e1ae..46116ad4d 100644
--- a/docs/tools_scripts.md
+++ b/docs/tools_scripts.md
@@ -79,3 +79,182 @@ Description of arguments:
 - `--dynamic-export`: Determines whether to export ONNX model with dynamic input and output shapes. If not specified, it will be set to `False`.
 
 **Note**: This tool is still experimental. Some customized operators are not supported for now. And we only support `mattor` and `restorer` for now.
+
+#### List of supported models exportable to ONNX
+
+The table below lists the models that are guaranteed to be exportable to ONNX and runnable in ONNX Runtime.
+
+|  Model   | Config | Dynamic Shape | Batch Inference | Note |
+| :------: | :----: | :-----------: | :-------------: | :--: |
+|  ESRGAN  | [esrgan_x4c64b23g32_g1_400k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/esrgan/esrgan_x4c64b23g32_g1_400k_div2k.py) | Y | Y | |
+|  ESRGAN  | [esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py) | Y | Y | |
+|  SRCNN   | [srcnn_x4k915_g1_1000k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py) | Y | Y | |
+|   DIM    | [dim_stage3_v16_pln_1x1_1000k_comp1k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/mattors/dim/dim_stage3_v16_pln_1x1_1000k_comp1k.py) | Y | Y | |
+|   GCA    | [gca_r34_4x10_200k_comp1k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/mattors/gca/gca_r34_4x10_200k_comp1k.py) | N | Y | |
+| IndexNet | [indexnet_mobv2_1x16_78k_comp1k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/mattors/indexnet/indexnet_mobv2_1x16_78k_comp1k.py) | Y | Y | |
+
+**Notes**:
+
+- *All models above are tested with PyTorch==1.6.0 and onnxruntime==1.5.1.*
+- If you encounter any problems with the models listed above, please create an issue and we will look into it as soon as possible. For models not included in the list, please try to solve them by yourself.
+- Because this feature is experimental and may change quickly, please always try with the latest `mmcv` and `mmedit`.
+
+### Evaluate ONNX Models with ONNX Runtime (experimental)
+
+We provide the tool `tools/deploy_test.py` to evaluate ONNX models with the ONNX Runtime backend.
+
+#### Prerequisite
+
+- Install onnx and onnxruntime-gpu
+
+  ```shell
+  pip install onnx onnxruntime-gpu
+  ```
+
+#### Usage
+
+```bash
+python tools/deploy_test.py \
+    ${CONFIG_FILE} \
+    ${ONNX_FILE} \
+    --out ${OUTPUT_FILE} \
+    --save-path ${SAVE_PATH} \
+    --cfg-options ${CFG_OPTIONS}
+```
+
+#### Description of all arguments
+
+- `config`: The path of a model config file.
+- `model`: The path of an ONNX model file.
+- `--out`: The path of the output result file in pickle format.
+- `--save-path`: The path to store images. If not given, images will not be saved.
+- `--cfg-options`: Override some settings in the used config file; key-value pairs in `xxx=yyy` format will be merged into the config file.
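+
+For example, to evaluate an exported ESRGAN model on the datasets from its config (a sketch: `work_dirs/esrgan.onnx` and the output paths below are placeholder names, substitute your own):
+
+```bash
+python tools/deploy_test.py \
+    configs/restorers/esrgan/esrgan_x4c64b23g32_g1_400k_div2k.py \
+    work_dirs/esrgan.onnx \
+    --out work_dirs/results.pkl \
+    --save-path work_dirs/onnx_images
+```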
+
+#### Results and Models
+
+| Model | Config | Dataset | Metric | PyTorch | ONNX Runtime |
+| :---: | :----: | :-----: | :----: | :-----: | :----------: |
+| ESRGAN | [esrgan_x4c64b23g32_g1_400k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/esrgan/esrgan_x4c64b23g32_g1_400k_div2k.py) | Set5 | PSNR | 28.2700 | 28.2619 |
+| ESRGAN | [esrgan_x4c64b23g32_g1_400k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/esrgan/esrgan_x4c64b23g32_g1_400k_div2k.py) | Set5 | SSIM | 0.7778 | 0.7784 |
+| ESRGAN | [esrgan_x4c64b23g32_g1_400k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/esrgan/esrgan_x4c64b23g32_g1_400k_div2k.py) | Set14 | PSNR | 24.6328 | 24.6290 |
+| ESRGAN | [esrgan_x4c64b23g32_g1_400k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/esrgan/esrgan_x4c64b23g32_g1_400k_div2k.py) | Set14 | SSIM | 0.6491 | 0.6494 |
+| ESRGAN | [esrgan_x4c64b23g32_g1_400k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/esrgan/esrgan_x4c64b23g32_g1_400k_div2k.py) | DIV2K | PSNR | 26.6531 | 26.6532 |
+| ESRGAN | [esrgan_x4c64b23g32_g1_400k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/esrgan/esrgan_x4c64b23g32_g1_400k_div2k.py) | DIV2K | SSIM | 0.7340 | 0.7340 |
+| ESRGAN | [esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py) | Set5 | PSNR | 30.6428 | 30.6307 |
+| ESRGAN | [esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py) | Set5 | SSIM | 0.8559 | 0.8565 |
+| ESRGAN | [esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py) | Set14 | PSNR | 27.0543 | 27.0422 |
+| ESRGAN | [esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py) | Set14 | SSIM | 0.7447 | 0.7450 |
+| ESRGAN | [esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py) | DIV2K | PSNR | 29.3354 | 29.3354 |
+| ESRGAN | [esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py) | DIV2K | SSIM | 0.8263 | 0.8263 |
+| SRCNN | [srcnn_x4k915_g1_1000k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py) | Set5 | PSNR | 28.4316 | 28.4120 |
+| SRCNN | [srcnn_x4k915_g1_1000k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py) | Set5 | SSIM | 0.8099 | 0.8106 |
+| SRCNN | [srcnn_x4k915_g1_1000k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py) | Set14 | PSNR | 25.6486 | 25.6367 |
+| SRCNN | [srcnn_x4k915_g1_1000k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py) | Set14 | SSIM | 0.7014 | 0.7015 |
+| SRCNN | [srcnn_x4k915_g1_1000k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py) | DIV2K | PSNR | 27.7460 | 27.7460 |
+| SRCNN | [srcnn_x4k915_g1_1000k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py) | DIV2K | SSIM | 0.7854 | 0.78543 |
+
+**Notes**:
+
+- All ONNX models are evaluated with dynamic shapes on the datasets above, and images are preprocessed according to the original config file.
+- This tool is still experimental, and we only support `restorer` for now.
diff --git a/mmedit/core/export/__init__.py b/mmedit/core/export/__init__.py
new file mode 100644
index 000000000..05e413aad
--- /dev/null
+++ b/mmedit/core/export/__init__.py
@@ -0,0 +1,3 @@
+from .wrappers import ONNXRuntimeEditing
+
+__all__ = ['ONNXRuntimeEditing']
diff --git a/mmedit/core/export/wrappers.py b/mmedit/core/export/wrappers.py
new file mode 100644
index 000000000..9b63c7c12
--- /dev/null
+++ b/mmedit/core/export/wrappers.py
@@ -0,0 +1,133 @@
+import os.path as osp
+import warnings
+
+import numpy as np
+import onnxruntime as ort
+import torch
+from torch import nn
+
+from mmedit.models import BaseMattor, BasicRestorer, build_model
+
+
+def inference_with_session(sess, io_binding, output_names, input_tensor):
+    # Bind the input tensor's memory to the session directly so no extra
+    # host copy is needed, then fetch the outputs back to CPU.
+    device_type = input_tensor.device.type
+    device_id = input_tensor.device.index
+    device_id = 0 if device_id is None else device_id
+    io_binding.bind_input(
+        name='input',
+        device_type=device_type,
+        device_id=device_id,
+        element_type=np.float32,
+        shape=input_tensor.shape,
+        buffer_ptr=input_tensor.data_ptr())
+    for name in output_names:
+        io_binding.bind_output(name)
+    sess.run_with_iobinding(io_binding)
+    pred = io_binding.copy_outputs_to_cpu()
+    return pred
+
+
+class ONNXRuntimeMattor(nn.Module):
+
+    def __init__(self, sess, io_binding, output_names, base_model):
+        super(ONNXRuntimeMattor, self).__init__()
+        self.sess = sess
+        self.io_binding = io_binding
+        self.output_names = output_names
+        self.base_model = base_model
+
+    def forward(self,
+                merged,
+                trimap,
+                meta,
+                test_mode=False,
+                save_image=False,
+                save_path=None,
+                iteration=None):
+        input_tensor = torch.cat((merged, trimap), 1).contiguous()
+        pred_alpha = inference_with_session(self.sess, self.io_binding,
+                                            self.output_names,
+                                            input_tensor)[0]
+
+        pred_alpha = pred_alpha.squeeze()
+        pred_alpha = self.base_model.restore_shape(pred_alpha, meta)
+        eval_result = self.base_model.evaluate(pred_alpha, meta)
+
+        if save_image:
+            self.base_model.save_image(pred_alpha, meta, save_path,
+                                       iteration)
+
+        return {'pred_alpha': pred_alpha, 'eval_result': eval_result}
+
+
+class RestorerGenerator(nn.Module):
+
+    def __init__(self, sess, io_binding, output_names):
+        super(RestorerGenerator, self).__init__()
+        self.sess = sess
+        self.io_binding = io_binding
+        self.output_names = output_names
+
+    def forward(self, x):
+        pred = inference_with_session(self.sess, self.io_binding,
+                                      self.output_names, x)[0]
+        pred = torch.from_numpy(pred)
+        return pred
+
+
+class ONNXRuntimeRestorer(nn.Module):
+
+    def __init__(self, sess, io_binding, output_names, base_model):
+        super(ONNXRuntimeRestorer, self).__init__()
+        self.sess = sess
+        self.io_binding = io_binding
+        self.output_names = output_names
+        self.base_model = base_model
+        # Swap the PyTorch generator for the ONNX Runtime session, so the
+        # rest of the restorer pipeline stays unchanged.
+        restorer_generator = RestorerGenerator(self.sess, self.io_binding,
+                                               self.output_names)
+        base_model.generator = restorer_generator
+
+    def forward(self, lq, gt=None, test_mode=False, **kwargs):
+        return self.base_model(lq, gt=gt, test_mode=test_mode, **kwargs)
+
+
+class ONNXRuntimeEditing(nn.Module):
+    """Run an exported mmedit model with ONNX Runtime, dispatching to a
+    mattor or restorer wrapper according to the model built from `cfg`."""
+
+    def __init__(self, onnx_file, cfg, device_id):
+        super(ONNXRuntimeEditing, self).__init__()
+        ort_custom_op_path = ''
+        try:
+            from mmcv.ops import get_onnxruntime_op_path
+            ort_custom_op_path = get_onnxruntime_op_path()
+        except (ImportError, ModuleNotFoundError):
+            warnings.warn('If input model has custom op from mmcv, '
+                          'you may have to build mmcv with ONNXRuntime '
+                          'from source.')
+        session_options = ort.SessionOptions()
+        # register custom op for onnxruntime
+        if osp.exists(ort_custom_op_path):
+            session_options.register_custom_ops_library(ort_custom_op_path)
+        sess = ort.InferenceSession(onnx_file, session_options)
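+        # Prefer the CUDA execution provider when the installed onnxruntime
+        # build reports GPU support; otherwise run on CPU only.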
+        providers = ['CPUExecutionProvider']
+        options = [{}]
+        is_cuda_available = ort.get_device() == 'GPU'
+        if is_cuda_available:
+            providers.insert(0, 'CUDAExecutionProvider')
+            options.insert(0, {'device_id': device_id})
+
+        sess.set_providers(providers, options)
+
+        self.sess = sess
+        self.device_id = device_id
+        self.io_binding = sess.io_binding()
+        self.output_names = [_.name for _ in sess.get_outputs()]
+
+        base_model = build_model(
+            cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
+
+        if isinstance(base_model, BaseMattor):
+            WrapperClass = ONNXRuntimeMattor
+        elif isinstance(base_model, BasicRestorer):
+            WrapperClass = ONNXRuntimeRestorer
+        else:
+            raise TypeError('Only mattor and restorer models are supported, '
+                            'but got {}'.format(type(base_model)))
+        self.wrapper = WrapperClass(self.sess, self.io_binding,
+                                    self.output_names, base_model)
+
+    def forward(self, **kwargs):
+        return self.wrapper(**kwargs)
diff --git a/requirements/tests.txt b/requirements/tests.txt
index 2747d91ae..db2805286 100644
--- a/requirements/tests.txt
+++ b/requirements/tests.txt
@@ -2,5 +2,6 @@ codecov
 flake8
 interrogate
 isort==4.3.21
+onnxruntime
 pytest
 pytest-runner
diff --git a/setup.cfg b/setup.cfg
index da38d8823..bacf0caf0 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -17,6 +17,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = pkg_resources,setuptools
 known_first_party = mmedit
-known_third_party =PIL,cv2,lmdb,mmcv,numpy,onnx,onnxruntime,pymatting,pytest,scipy,titlecase,torch,torchvision
+known_third_party =PIL,cv2,lmdb,mmcv,numpy,onnx,onnxruntime,packaging,pymatting,pytest,scipy,titlecase,torch,torchvision
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
diff --git a/tests/test_onnx_wrapper.py b/tests/test_onnx_wrapper.py
new file mode 100644
index 000000000..29bce6cb0
--- /dev/null
+++ b/tests/test_onnx_wrapper.py
@@ -0,0 +1,154 @@
+import os
+
+import mmcv
+import numpy as np
+import pytest
+import torch
+from packaging import version
+
+from mmedit.models import build_model
+
+
+@pytest.mark.skipif(torch.__version__ == 'parrots', reason='skip parrots.')
+@pytest.mark.skipif(
+    version.parse(torch.__version__) < version.parse('1.4.0'),
+    reason='skip if torch < 1.4.0')
+def test_restorer_wrapper():
+    try:
+        import onnxruntime as ort
+        from mmedit.core.export.wrappers import (ONNXRuntimeEditing,
+                                                 ONNXRuntimeRestorer)
+    except ImportError:
+        pytest.skip('ONNXRuntime is not available.')
+
+    onnx_path = 'tmp.onnx'
+    scale = 4
+    train_cfg = None
+    test_cfg = None
+    cfg = dict(
+        model=dict(
+            type='BasicRestorer',
+            generator=dict(
+                type='SRCNN',
+                channels=(3, 4, 2, 3),
+                kernel_sizes=(9, 1, 5),
+                upscale_factor=scale),
+            pixel_loss=dict(type='L1Loss', loss_weight=1.0,
+                            reduction='mean')),
+        train_cfg=train_cfg,
+        test_cfg=test_cfg)
+    cfg = mmcv.Config(cfg)
+
+    pytorch_model = build_model(
+        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
+
+    # prepare data
+    inputs = torch.rand(1, 3, 2, 2)
+    targets = torch.rand(1, 3, 8, 8)
+    data_batch = {'lq': inputs, 'gt': targets}
+
+    # export the dummy forward pass (generator only) to ONNX
+    pytorch_model.forward = pytorch_model.forward_dummy
+    with torch.no_grad():
+        torch.onnx.export(
+            pytorch_model,
+            inputs,
+            onnx_path,
+            input_names=['input'],
+            output_names=['output'],
+            export_params=True,
+            keep_initializers_as_inputs=False,
+            verbose=False,
+            opset_version=11)
+
+    wrap_model = ONNXRuntimeEditing(onnx_path, cfg, 0)
+    os.remove(onnx_path)
+    assert isinstance(wrap_model.wrapper, ONNXRuntimeRestorer)
+
+    if ort.get_device() == 'GPU':
+        data_batch = {'lq': inputs.cuda(), 'gt': targets.cuda()}
+
+    with torch.no_grad():
+        outputs = wrap_model(**data_batch, test_mode=True)
+
+    assert isinstance(outputs, dict)
+    assert 'output' in outputs
+    output = outputs['output']
+    assert isinstance(output, torch.Tensor)
+    assert output.shape == targets.shape
+
+
+@pytest.mark.skipif(torch.__version__ == 'parrots', reason='skip parrots.')
+@pytest.mark.skipif(
+    version.parse(torch.__version__) < version.parse('1.4.0'),
+    reason='skip if torch < 1.4.0')
+def test_mattor_wrapper():
+    try:
+        import onnxruntime as ort
+        from mmedit.core.export.wrappers import (ONNXRuntimeEditing,
+                                                 ONNXRuntimeMattor)
+    except ImportError:
+        pytest.skip('ONNXRuntime is not available.')
+    onnx_path = 'tmp.onnx'
+    train_cfg = None
+    test_cfg = dict(refine=False, metrics=['SAD', 'MSE', 'GRAD', 'CONN'])
+    cfg = dict(
+        model=dict(
+            type='DIM',
+            backbone=dict(
+                type='SimpleEncoderDecoder',
+                encoder=dict(type='VGG16', in_channels=4),
+                decoder=dict(type='PlainDecoder')),
+            pretrained='open-mmlab://mmedit/vgg16',
+            loss_alpha=dict(type='CharbonnierLoss', loss_weight=0.5),
+            loss_comp=dict(type='CharbonnierCompLoss', loss_weight=0.5)),
+        train_cfg=train_cfg,
+        test_cfg=test_cfg)
+    cfg = mmcv.Config(cfg)
+
+    pytorch_model = build_model(
+        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
+
+    # prepare data: merged image plus trimap, concatenated for export
+    img_shape = (32, 32)
+    merged = torch.rand(1, 3, img_shape[1], img_shape[0])
+    trimap = torch.rand(1, 1, img_shape[1], img_shape[0])
+    data_batch = {'merged': merged, 'trimap': trimap}
+    inputs = torch.cat([merged, trimap], dim=1)
+
+    pytorch_model.forward = pytorch_model.forward_dummy
+    with torch.no_grad():
+        torch.onnx.export(
+            pytorch_model,
+            inputs,
+            onnx_path,
+            input_names=['input'],
+            output_names=['output'],
+            export_params=True,
+            keep_initializers_as_inputs=False,
+            verbose=False,
+            opset_version=11)
+
+    wrap_model = ONNXRuntimeEditing(onnx_path, cfg, 0)
+    os.remove(onnx_path)
+    assert isinstance(wrap_model.wrapper, ONNXRuntimeMattor)
+
+    if ort.get_device() == 'GPU':
+        merged = merged.cuda()
+        trimap = trimap.cuda()
+        data_batch = {'merged': merged, 'trimap': trimap}
+
+    ori_alpha = np.random.random(img_shape).astype(np.float32)
+    ori_trimap = np.random.randint(256, size=img_shape).astype(np.float32)
+    data_batch['meta'] = [
+        dict(
+            ori_alpha=ori_alpha,
+            ori_trimap=ori_trimap,
+            merged_ori_shape=img_shape)
+    ]
+
+    with torch.no_grad():
+        outputs = wrap_model(**data_batch, test_mode=True)
+
+    assert isinstance(outputs, dict)
+    assert 'pred_alpha' in outputs
+    pred_alpha = outputs['pred_alpha']
+    assert isinstance(pred_alpha, np.ndarray)
+    assert pred_alpha.shape == img_shape
diff --git a/tools/deploy_test.py b/tools/deploy_test.py
new file mode 100644
index 000000000..6bda9564c
--- /dev/null
+++ b/tools/deploy_test.py
@@ -0,0 +1,86 @@
+import argparse
+
+import mmcv
+from mmcv import Config, DictAction
+from mmcv.parallel import MMDataParallel
+
+from mmedit.apis import single_gpu_test
+from mmedit.core.export import ONNXRuntimeEditing
+from mmedit.datasets import build_dataloader, build_dataset
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='mmediting tester')
+    parser.add_argument('config', help='test config file path')
+    parser.add_argument('model', help='input model file')
+    parser.add_argument('--out', help='output result pickle file')
+    parser.add_argument(
+        '--save-path',
+        default=None,
+        type=str,
+        help='path to store images; if not given, images will not be saved')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
+    args = parser.parse_args()
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    cfg = Config.fromfile(args.config)
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+
+    # ONNX Runtime evaluation only supports single-GPU, non-distributed
+    # testing
+    distributed = False
+
+    # build the dataloader
+    dataset = build_dataset(cfg.data.test)
+
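+    # Dataset-side defaults from cfg.data come first, then the fixed
+    # single-sample test settings, then any `test_dataloader` overrides
+    # from the config file.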
+    loader_cfg = {
+        **dict((k, cfg.data[k]) for k in ['workers_per_gpu'] if k in cfg.data),
+        **dict(
+            samples_per_gpu=1,
+            drop_last=False,
+            shuffle=False,
+            dist=distributed),
+        **cfg.data.get('test_dataloader', {})
+    }
+
+    data_loader = build_dataloader(dataset, **loader_cfg)
+
+    # build the model
+    model = ONNXRuntimeEditing(args.model, cfg=cfg, device_id=0)
+
+    args.save_image = args.save_path is not None
+    model = MMDataParallel(model, device_ids=[0])
+    outputs = single_gpu_test(
+        model,
+        data_loader,
+        save_path=args.save_path,
+        save_image=args.save_image)
+
+    print()
+    # print metrics
+    stats = dataset.evaluate(outputs)
+    for stat in stats:
+        print('Eval-{}: {}'.format(stat, stats[stat]))
+
+    # save result pickle
+    if args.out:
+        print('writing results to {}'.format(args.out))
+        mmcv.dump(outputs, args.out)
+
+
+if __name__ == '__main__':
+    main()