Skip to content

Commit

Permalink
[Feature]: Add onnx evaluation tool (open-mmlab#279)
Browse files Browse the repository at this point in the history
* update doc for onnx eval tool

* add ort eval tool

* add mattor test support

* update document

* update doc

* add format to html table

* add onnx wraper test

* fix lint

* better onnx unit test

* skip if torch==1.3.x

Co-authored-by: q.yao <[email protected]>
  • Loading branch information
RunningLeon and q.yao authored Apr 29, 2021
1 parent 29d6487 commit 4d2a23d
Show file tree
Hide file tree
Showing 7 changed files with 557 additions and 1 deletion.
179 changes: 179 additions & 0 deletions docs/tools_scripts.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,182 @@ Description of arguments:
- `--dynamic-export`: Determines whether to export ONNX model with dynamic input and output shapes. If not specified, it will be set to `False`.

**Note**: This tool is still experimental. Some customized operators are not supported for now. And we only support `mattor` and `restorer` for now.

#### List of supported models exportable to ONNX

The table below lists the models that are guaranteed to be exportable to ONNX and runnable in ONNX Runtime.

| Model | Config | Dynamic Shape | Batch Inference | Note |
| :------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------: | :-------------: | :---: |
| ESRGAN | [esrgan_x4c64b23g32_g1_400k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/esrgan/esrgan_x4c64b23g32_g1_400k_div2k.py) | Y | Y | |
| ESRGAN | [esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py) | Y | Y | |
| SRCNN | [srcnn_x4k915_g1_1000k_div2k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py) | Y | Y | |
| DIM | [dim_stage3_v16_pln_1x1_1000k_comp1k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/mattors/dim/dim_stage3_v16_pln_1x1_1000k_comp1k.py) | Y | Y | |
| GCA | [gca_r34_4x10_200k_comp1k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/mattors/gca/gca_r34_4x10_200k_comp1k.py) | N | Y | |
| IndexNet | [indexnet_mobv2_1x16_78k_comp1k.py](https://github.com/open-mmlab/mmediting/blob/master/configs/mattors/indexnet/indexnet_mobv2_1x16_78k_comp1k.py) | Y | Y | |

**Notes**:

- *All models above are tested with Pytorch==1.6.0 and onnxruntime==1.5.1*
- If you meet any problem with the listed models above, please create an issue and it would be taken care of soon. For models not included in the list, please try to solve them by yourself.
- Because this feature is experimental and may change fast, please always try with the latest `mmcv` and `mmedit`.

### Evaluate ONNX Models with ONNXRuntime (experimental)

We prepare a tool `tools/deploy_test.py` to evaluate ONNX models with ONNX Runtime backend.

#### Prerequisite

- Install onnx and onnxruntime-gpu

```shell
pip install onnx onnxruntime-gpu
```

#### Usage

```bash
python tools/deploy_test.py \
${CONFIG_FILE} \
${ONNX_FILE} \
--out ${OUTPUT_FILE} \
--save-path ${SAVE_PATH} \
----cfg-options ${CFG_OPTIONS} \
```

#### Description of all arguments

- `config`: The path of a model config file.
- `model`: The path of an ONNX model file.
- `--out`: The path of output result file in pickle format.
- `--save-path`: The path to store images and if not given, it will not save image.
- `--cfg-options`: Override some settings in the used config file, the key-value pair in `xxx=yyy` format will be merged into config file.

#### Results and Models

<table border="1" class="docutils">
<tr>
<th align="center">Model</th>
<th align="center">Config</th>
<th align="center">Dataset</th>
<th align="center">Metric</th>
<th align="center">PyTorch</th>
<th align="center">ONNX Runtime</th>
</tr>
<tr>
<td align="center" rowspan="6">ESRGAN</td>
<td align="center" rowspan="6">
<code>esrgan_x4c64b23g32_g1_400k_div2k.py</code>
</td>
<td align="center" rowspan="2">Set5</td>
<td align="center">PSNR</td>
<td align="center">28.2700</td>
<td align="center">28.2619</td>
</tr>
<tr>
<td align="center">SSIM</td>
<td align="center">0.7778</td>
<td align="center">0.7784</td>
</tr>
<tr>
<td align="center" rowspan="2">Set14</td>
<td align="center">PSNR</td>
<td align="center">24.6328</td>
<td align="center">24.6290</td>
</tr>
<tr>
<td align="center">SSIM</td>
<td align="center">0.6491</td>
<td align="center">0.6494</td>
</tr>
<tr>
<td align="center" rowspan="2">DIV2K</td>
<td align="center">PSNR</td>
<td align="center">26.6531</td>
<td align="center">26.6532</td>
</tr>
<tr>
<td align="center">SSIM</td>
<td align="center">0.7340</td>
<td align="center">0.7340</td>
</tr>
<tr>
<td align="center" rowspan="6">ESRGAN</td>
<td align="center" rowspan="6">
<code>esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py</code>
</td>
<td align="center" rowspan="2">Set5</td>
<td align="center">PSNR</td>
<td align="center">30.6428</td>
<td align="center">30.6307</td>
</tr>
<tr>
<td align="center">SSIM</td>
<td align="center">0.8559</td>
<td align="center">0.8565</td>
</tr>
<tr>
<td align="center" rowspan="2">Set14</td>
<td align="center">PSNR</td>
<td align="center">27.0543</td>
<td align="center">27.0422</td>
</tr>
<tr>
<td align="center">SSIM</td>
<td align="center">0.7447</td>
<td align="center">0.7450</td>
</tr>
<tr>
<td align="center" rowspan="2">DIV2K</td>
<td align="center">PSNR</td>
<td align="center">29.3354</td>
<td align="center">29.3354</td>
</tr>
<tr>
<td align="center">SSIM</td>
<td align="center">0.8263</td>
<td align="center">0.8263</td>
</tr>
<tr>
<td align="center" rowspan="6">SRCNN</td>
<td align="center" rowspan="6">
<code>srcnn_x4k915_g1_1000k_div2k.py</code>
</td>
<td align="center" rowspan="2">Set5</td>
<td align="center">PSNR</td>
<td align="center">28.4316</td>
<td align="center">28.4120</td>
</tr>
<tr>
<td align="center">SSIM</td>
<td align="center">0.8099</td>
<td align="center">0.8106</td>
</tr>
<tr>
<td align="center" rowspan="2">Set14</td>
<td align="center">PSNR</td>
<td align="center">25.6486</td>
<td align="center">25.6367</td>
</tr>
<tr>
<td align="center">SSIM</td>
<td align="center">0.7014</td>
<td align="center">0.7015</td>
</tr>
<tr>
<td align="center" rowspan="2">DIV2K</td>
<td align="center">PSNR</td>
<td align="center">27.7460</td>
<td align="center">27.7460</td>
</tr>
<tr>
<td align="center">SSIM</td>
<td align="center">0.7854</td>
<td align="center">0.78543</td>
</tr>
</table>

**Notes**:

- All ONNX models are evaluated with dynamic shape on the datasets and images are preprocessed according to the original config file.
- This tool is still experimental, and we only support `restorer` for now.
3 changes: 3 additions & 0 deletions mmedit/core/export/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .wrappers import ONNXRuntimeEditing

__all__ = ['ONNXRuntimeEditing']
133 changes: 133 additions & 0 deletions mmedit/core/export/wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import os.path as osp
import warnings

import numpy as np
import onnxruntime as ort
import torch
from torch import nn

from mmedit.models import BaseMattor, BasicRestorer, build_model


def inference_with_session(sess, io_binding, output_names, input_tensor):
device_type = input_tensor.device.type
device_id = input_tensor.device.index
device_id = 0 if device_id is None else device_id
io_binding.bind_input(
name='input',
device_type=device_type,
device_id=device_id,
element_type=np.float32,
shape=input_tensor.shape,
buffer_ptr=input_tensor.data_ptr())
for name in output_names:
io_binding.bind_output(name)
sess.run_with_iobinding(io_binding)
pred = io_binding.copy_outputs_to_cpu()
return pred


class ONNXRuntimeMattor(nn.Module):

def __init__(self, sess, io_binding, output_names, base_model):
super(ONNXRuntimeMattor, self).__init__()
self.sess = sess
self.io_binding = io_binding
self.output_names = output_names
self.base_model = base_model

def forward(self,
merged,
trimap,
meta,
test_mode=False,
save_image=False,
save_path=None,
iteration=None):
input_tensor = torch.cat((merged, trimap), 1).contiguous()
pred_alpha = inference_with_session(self.sess, self.io_binding,
self.output_names, input_tensor)[0]

pred_alpha = pred_alpha.squeeze()
pred_alpha = self.base_model.restore_shape(pred_alpha, meta)
eval_result = self.base_model.evaluate(pred_alpha, meta)

if save_image:
self.base_model.save_image(pred_alpha, meta, save_path, iteration)

return {'pred_alpha': pred_alpha, 'eval_result': eval_result}


class RestorerGenerator(nn.Module):

def __init__(self, sess, io_binding, output_names):
super(RestorerGenerator, self).__init__()
self.sess = sess
self.io_binding = io_binding
self.output_names = output_names

def forward(self, x):
pred = inference_with_session(self.sess, self.io_binding,
self.output_names, x)[0]
pred = torch.from_numpy(pred)
return pred


class ONNXRuntimeRestorer(nn.Module):

def __init__(self, sess, io_binding, output_names, base_model):
super(ONNXRuntimeRestorer, self).__init__()
self.sess = sess
self.io_binding = io_binding
self.output_names = output_names
self.base_model = base_model
restorer_generator = RestorerGenerator(self.sess, self.io_binding,
self.output_names)
base_model.generator = restorer_generator

def forward(self, lq, gt=None, test_mode=False, **kwargs):
return self.base_model(lq, gt=gt, test_mode=test_mode, **kwargs)


class ONNXRuntimeEditing(nn.Module):

def __init__(self, onnx_file, cfg, device_id):
super(ONNXRuntimeEditing, self).__init__()
ort_custom_op_path = ''
try:
from mmcv.ops import get_onnxruntime_op_path
ort_custom_op_path = get_onnxruntime_op_path()
except (ImportError, ModuleNotFoundError):
warnings.warn('If input model has custom op from mmcv, \
you may have to build mmcv with ONNXRuntime from source.')
session_options = ort.SessionOptions()
# register custom op for onnxruntime
if osp.exists(ort_custom_op_path):
session_options.register_custom_ops_library(ort_custom_op_path)
sess = ort.InferenceSession(onnx_file, session_options)
providers = ['CPUExecutionProvider']
options = [{}]
is_cuda_available = ort.get_device() == 'GPU'
if is_cuda_available:
providers.insert(0, 'CUDAExecutionProvider')
options.insert(0, {'device_id': device_id})

sess.set_providers(providers, options)

self.sess = sess
self.device_id = device_id
self.io_binding = sess.io_binding()
self.output_names = [_.name for _ in sess.get_outputs()]

base_model = build_model(
cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)

if isinstance(base_model, BaseMattor):
WraperClass = ONNXRuntimeMattor
elif isinstance(base_model, BasicRestorer):
WraperClass = ONNXRuntimeRestorer
self.wraper = WraperClass(self.sess, self.io_binding,
self.output_names, base_model)

def forward(self, **kwargs):
return self.wraper(**kwargs)
1 change: 1 addition & 0 deletions requirements/tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ codecov
flake8
interrogate
isort==4.3.21
onnxruntime
pytest
pytest-runner
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = mmedit
known_third_party =PIL,cv2,lmdb,mmcv,numpy,onnx,onnxruntime,pymatting,pytest,scipy,titlecase,torch,torchvision
known_third_party =PIL,cv2,lmdb,mmcv,numpy,onnx,onnxruntime,packaging,pymatting,pytest,scipy,titlecase,torch,torchvision
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
Loading

0 comments on commit 4d2a23d

Please sign in to comment.