ESRGAN example (#847)
Summary:
This PR adds an [ESRGAN](https://github.com/xinntao/Real-ESRGAN) upscaler example. It is a draft for now; see Notes.

Tested with [`RealESRGAN_x4plus`](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth). I did not find safetensors versions of these models, but the safetensors format is supported.

Compile:
```
python compile.py --model-path "RealESRGAN_x4plus.safetensors" --include-constants True
```

Demo:
```
python demo.py --module-path ./tmp/ESRGANModel/test.so --input-image-path "nvidiafu.png" --output-image-path "nvidiafu_upscaled.png"
```

Input image:
![nvidiafu](https://github.com/facebookincubator/AITemplate/assets/106811348/69b86c01-7b3d-46e1-b460-e7ee16af1414)

Upscaled image:
![nvidiafu_upscaled](https://github.com/facebookincubator/AITemplate/assets/106811348/b3dc6c6c-fe6a-437e-a48a-b5c0c8959f8c)

## Notes
* This depends on #846.
* Demo pipeline is basic - no batch support, no alpha channel support.
    * Dynamic batch is untested but should work.
    * ESRGAN does support alpha channel, it is just not implemented in the demo pipeline.
* Dynamic shape (height, width) is supported.
* Correct scale must be given to `demo.py`.
* Model arch can be configured using compile.py arguments `--num-in-ch`, `--num-out-ch`, `--num-feat`, `--num-block`, `--num-grow-ch`, `--scale`.
    * `--num-block` and `--scale` are most likely to be changed.
    * e.g. `--num-block 6` for `RealESRGAN_x4plus_anime_6B` or `--scale 2` for `RealESRGAN_x2plus`.
* Scale != 4 models e.g. `RealESRGAN_x2plus` support static shape only.
* Needs documentation.

After the dependency is merged and documentation is added, I will mark this PR as ready.

The demo pipeline could be improved, but it is only a demo; a full-featured implementation is left to the developer integrating this module.

Pull Request resolved: #847

Reviewed By: sgrigory

Differential Revision: D48425425

Pulled By: aakhundov

fbshipit-source-id: 38602581fd8991540240d8c58d274af905803876
hlky authored and facebook-github-bot committed Aug 17, 2023
1 parent 179bfb2 commit e3d3214
Showing 4 changed files with 633 additions and 0 deletions.
137 changes: 137 additions & 0 deletions examples/08_esrgan/README.md
@@ -0,0 +1,137 @@
## ESRGAN Example

In this example, we show how to build fast AIT modules for ESRGAN models, and benchmark/run them.

### Build Dependencies

First, clone, build, and install AITemplate [per the README instructions](https://github.com/facebookincubator/AITemplate#clone-the-code).

This AIT ESRGAN example depends on `torch`, `click` and optionally `safetensors`. You can install them with `pip`.

### Download the ESRGAN model

We have tested the following ESRGAN models.

[RealESRGAN_x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth)

Model architecture:
```
num_in_ch: 3,
num_out_ch: 3,
num_feat: 64,
num_block: 23,
num_grow_ch: 32,
scale: 4,
```


[RealESRGAN_x4plus_anime_6B](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth)

Model architecture:
```
num_in_ch: 3,
num_out_ch: 3,
num_feat: 64,
num_block: 6,
num_grow_ch: 32,
scale: 4,
```


[RealESRGAN_x2plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth)

Model architecture:
```
num_in_ch: 3,
num_out_ch: 3,
num_feat: 64,
num_block: 23,
num_grow_ch: 32,
scale: 2,
```

A database of ESRGAN models can be found [here](https://upscale.wiki/wiki/Model_Database).

Safetensors versions are supported.


### Build AIT modules for ESRGAN

Build the AIT modules by running `compile.py`.

```
Usage: compile.py [OPTIONS]

Options:
  --model-path TEXT               model path. supports torch or safetensors
  --width <INTEGER INTEGER>...    Minimum and maximum width
  --height <INTEGER INTEGER>...   Minimum and maximum height
  --batch-size <INTEGER INTEGER>...
                                  Minimum and maximum batch size
  --include-constants BOOLEAN     include constants (model weights) with
                                  compiled model
  --num-in-ch INTEGER             Number of in channels
  --num-out-ch INTEGER            Number of out channels
  --num-feat INTEGER              Number of intermediate features
  --num-block INTEGER             Number of RRDB layers
  --num-grow-ch INTEGER           Number of channels for each growth
  --scale INTEGER                 Scale
  --use-fp16-acc BOOLEAN          use fp16 accumulation
  --convert-conv-to-gemm BOOLEAN  convert 1x1 conv to gemm
  --work-dir TEXT                 Work directory
  --model-name TEXT               Model name
  --help                          Show this message and exit.
```

Use `--num-in-ch`, `--num-out-ch`, `--num-feat`, `--num-block`, `--num-grow-ch`, `--scale` options to set the ESRGAN model architecture. The default values are for `RealESRGAN_x4plus` architecture.
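The three tested architectures differ from the defaults only in `--num-block` and `--scale`. As a small sketch of how the non-default flags could be derived from the listings above (the `ARCHS` table and the `compile_args` helper are hypothetical names for illustration and are not part of this example):

```python
# Architecture values copied from the model listings in this README.
DEFAULTS = {
    "num_in_ch": 3,
    "num_out_ch": 3,
    "num_feat": 64,
    "num_block": 23,
    "num_grow_ch": 32,
    "scale": 4,
}

# Hypothetical lookup table: model name -> architecture.
ARCHS = {
    "RealESRGAN_x4plus": dict(DEFAULTS),
    "RealESRGAN_x4plus_anime_6B": {**DEFAULTS, "num_block": 6},
    "RealESRGAN_x2plus": {**DEFAULTS, "scale": 2},
}


def compile_args(model_name):
    """Return only the compile.py flags that differ from the defaults."""
    args = []
    for key, value in ARCHS[model_name].items():
        if value != DEFAULTS[key]:
            # num_block -> --num-block, matching the option names above.
            args += ["--" + key.replace("_", "-"), str(value)]
    return args


if __name__ == "__main__":
    for name in ARCHS:
        print(name, compile_args(name))
```

The returned list can be appended to a `python compile.py --model-path ...` invocation.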

`--width` and `--height` each take a minimum and a maximum value, and the compiled module supports any resolution within that range. With the 2x model architecture, however, only static shapes are supported, and the maximum value of each dimension is used. The defaults are `64` and `1024`.
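To make the range rule concrete, here is a minimal sketch of a check an integrator might run before calling a compiled module. The helpers `effective_range` and `fits` are hypothetical and not part of this example; they assume, following the warning printed by `compile.py`, that any scale other than 4 collapses each range to its maximum.

```python
def effective_range(bounds, scale):
    """Return the (min, max) the compiled module accepts for one dimension."""
    low, high = bounds
    if low > high:
        raise ValueError("minimum must not exceed maximum")
    # Assumption: scale != 4 means static shape, so only the maximum is valid.
    return (low, high) if scale == 4 else (high, high)


def fits(image_hw, height=(64, 1024), width=(64, 1024), scale=4):
    """Check whether an image of size (height, width) is within the compiled range."""
    h, w = image_hw
    h_low, h_high = effective_range(height, scale)
    w_low, w_high = effective_range(width, scale)
    return h_low <= h <= h_high and w_low <= w <= w_high


if __name__ == "__main__":
    print(fits((512, 768)))             # inside the default 64..1024 range
    print(fits((512, 512), scale=2))    # static shape: only 1024x1024 is accepted
```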

`--batch-size` takes a minimum and a maximum value and is supported for the 4x model architecture. The default is `1 1`.

Use `--include-constants False` to compile the module without model weights.

AIT modules are compatible with all ESRGAN models with the same model architecture. This can simplify deployment by compiling a module without model weights then applying weights at runtime by using AIT mapped weights (see `map_rrdb` in `./modeling/rrdbnet.py`) with the module's `set_many_constants_with_tensors`.
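A rough sketch of that runtime flow follows. It mirrors the checkpoint handling in `compile.py` and assumes that `aitemplate.compiler.Model` can load the compiled `.so` and that `map_rrdb` returns a name-to-tensor dictionary suitable for `set_many_constants_with_tensors`. The function names `unwrap_state_dict` and `load_module_with_weights` are hypothetical, and the GPU-dependent part has not been verified here.

```python
def unwrap_state_dict(checkpoint):
    """Real-ESRGAN checkpoints may nest the weights under 'params_ema' or 'params'."""
    for key in ("params_ema", "params"):
        if key in checkpoint:
            return checkpoint[key]
    return checkpoint


def load_module_with_weights(module_path, checkpoint_path, scale=4):
    """Load a module compiled with --include-constants False and apply weights."""
    # Imported lazily so unwrap_state_dict stays usable without the GPU stack.
    import torch
    from aitemplate.compiler import Model
    from modeling.rrdbnet import map_rrdb  # from this example directory

    if checkpoint_path.endswith(".safetensors"):
        import safetensors.torch

        state = safetensors.torch.load_file(checkpoint_path)
    else:
        state = torch.load(checkpoint_path)

    # Assumption: map_rrdb yields AIT constant names mapped to torch tensors.
    constants = map_rrdb(unwrap_state_dict(state), scale=scale)
    module = Model(module_path)
    module.set_many_constants_with_tensors(constants)
    return module
```

For example, `load_module_with_weights("./tmp/ESRGANModel/test.so", "RealESRGAN_x4plus.pth")` would reuse one compiled module for any checkpoint with the same architecture.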

In our tests, an ESRGAN module compiled with weights is approximately `38MB`, and approximately `6.5MB` without.

Examples:

```
python compile.py --model-path "RealESRGAN_x4plus.safetensors"
```

```
python compile.py --model-path "RealESRGAN_x4plus_anime_6B.pth" --num-block 6 --model-name RealESRGAN_x4plus_anime_6B
```

```
python compile.py --model-path "RealESRGAN_x2plus.pth" --scale 2 --model-name RealESRGAN_x2plus --width 512 512 --height 512 512
```


#### Multi-GPU profiling
AIT needs to run profiling to select the best algorithms for CUTLASS and CK.
To enable multiple GPUs for profiling, set the environment variable `CUDA_VISIBLE_DEVICES` on NVIDIA platforms or `HIP_VISIBLE_DEVICES` on AMD platforms.
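As an illustration, the variable can be set for a single `compile.py` run from Python rather than in the shell. The helper `profiling_env` below is a hypothetical name; the `"rocm"` target string follows the check used in `compile.py`.

```python
import os


def profiling_env(device_ids, target="cuda", base_env=None):
    """Return a copy of the environment with the GPU visibility variable set."""
    env = dict(os.environ if base_env is None else base_env)
    name = "HIP_VISIBLE_DEVICES" if target == "rocm" else "CUDA_VISIBLE_DEVICES"
    env[name] = ",".join(str(i) for i in device_ids)
    return env


if __name__ == "__main__":
    # Pass the result as env= to subprocess.run([...compile.py command...]).
    env = profiling_env([0, 1])
    print(env["CUDA_VISIBLE_DEVICES"])
```

Building a copy of the environment keeps the setting local to the compile subprocess instead of changing it for the calling process.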


### Run Models

`demo.py` provides example usage of ESRGAN modules.

```
Usage: demo.py [OPTIONS]

Options:
  --module-path TEXT        the AIT module path
  --input-image-path TEXT   path to input image
  --output-image-path TEXT  path to output image
  --scale INTEGER           Scale of ESRGAN model
  --help                    Show this message and exit.
```

`--scale` must match the scale of the model architecture.

Limitations:
* Demo does not support multiple images/batch size.
* Demo does not support images with alpha channel.
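Because the demo handles one image per invocation, several images can be processed by calling it once per file. The sketch below only builds the command lines, using the `demo.py` options listed above. The helper `demo_commands` and the `_upscaled` output naming are assumptions made for illustration.

```python
import sys
from pathlib import Path


def demo_commands(module_path, image_paths, output_dir, scale=4):
    """Build one demo.py command per input image."""
    commands = []
    for image in map(Path, image_paths):
        output = Path(output_dir) / f"{image.stem}_upscaled{image.suffix}"
        commands.append(
            [
                sys.executable,
                "demo.py",
                "--module-path", str(module_path),
                "--input-image-path", str(image),
                "--output-image-path", str(output),
                # Must match the scale the module was compiled with.
                "--scale", str(scale),
            ]
        )
    return commands


if __name__ == "__main__":
    for cmd in demo_commands("./tmp/ESRGANModel/test.so", ["a.png", "b.png"], "out"):
        print(" ".join(cmd))
```

Each command can then be passed to `subprocess.run`. Note that this reloads the module for every image, so a real integration would more likely load it once and loop in-process.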
176 changes: 176 additions & 0 deletions examples/08_esrgan/compile.py
@@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging

import click
import safetensors.torch
import torch

from aitemplate.compiler import compile_model
from aitemplate.frontend import IntVar, Tensor
from aitemplate.testing import detect_target

from modeling.rrdbnet import map_rrdb, mark_output, RRDBNet


@click.command()
@click.option(
    "--model-path",
    default="RealESRGAN_x4plus.safetensors",
    help="model path. supports torch or safetensors",
)
@click.option(
    "--width",
    default=(64, 1024),
    type=(int, int),
    nargs=2,
    help="Minimum and maximum width",
)
@click.option(
    "--height",
    default=(64, 1024),
    type=(int, int),
    nargs=2,
    help="Minimum and maximum height",
)
@click.option(
    "--batch-size",
    default=(1, 1),
    type=(int, int),
    nargs=2,
    help="Minimum and maximum batch size",
)
@click.option(
    "--include-constants",
    default=True,
    type=bool,
    help="include constants (model weights) with compiled model",
)
@click.option(
    "--num-in-ch",
    default=3,
    type=int,
    help="Number of in channels",
)
@click.option(
    "--num-out-ch",
    default=3,
    type=int,
    help="Number of out channels",
)
@click.option(
    "--num-feat",
    default=64,
    type=int,
    help="Number of intermediate features",
)
@click.option(
    "--num-block",
    default=23,
    type=int,
    help="Number of RRDB layers",
)
@click.option(
    "--num-grow-ch",
    default=32,
    type=int,
    help="Number of channels for each growth",
)
@click.option(
    "--scale",
    default=4,
    type=int,
    help="Scale",
)
@click.option("--use-fp16-acc", default=True, help="use fp16 accumulation")
@click.option("--convert-conv-to-gemm", default=True, help="convert 1x1 conv to gemm")
@click.option("--work-dir", default="./tmp", help="Work directory")
@click.option("--model-name", default="ESRGANModel", help="Model name")
def compile_esrgan(
    model_path,
    width,
    height,
    batch_size,
    include_constants,
    num_in_ch,
    num_out_ch,
    num_feat,
    num_block,
    num_grow_ch,
    scale,
    use_fp16_acc=True,
    convert_conv_to_gemm=True,
    work_dir="./tmp",
    model_name="ESRGANModel",
):
    if scale != 4:
        print(
            "Scale != 4 supports static shape only. Maximum value of batch_size, height and width will be used."
        )

    logging.getLogger().setLevel(logging.INFO)
    torch.manual_seed(4896)

    if detect_target().name() == "rocm":
        convert_conv_to_gemm = False

    if model_path.endswith(".safetensors"):
        pt_model = safetensors.torch.load_file(model_path)
    else:
        pt_model = torch.load(model_path)

    if "params_ema" in pt_model.keys():
        pt_model = pt_model["params_ema"]
    elif "params" in pt_model.keys():
        pt_model = pt_model["params"]

    rrdbnet = RRDBNet(
        num_in_ch=num_in_ch,
        num_out_ch=num_out_ch,
        scale=scale,
        num_feat=num_feat,
        num_block=num_block,
        num_grow_ch=num_grow_ch,
    )
    rrdbnet.name_parameter_tensor()

    constants = map_rrdb(pt_model, scale=scale)

    batch_size = IntVar(values=list(batch_size), name="batch_size")
    channels = num_in_ch
    height = IntVar(values=list(height), name="height")
    width = IntVar(values=list(width), name="width")

    image = Tensor(
        shape=[batch_size, height, width, channels], name="input_pixels", is_input=True
    )

    Y = rrdbnet(image)
    Y = mark_output(Y, "upscaled_pixels")

    target = detect_target(
        use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm
    )
    compile_model(
        Y,
        target,
        work_dir,
        model_name,
        constants=constants if include_constants else None,
    )


if __name__ == "__main__":
    compile_esrgan()