Summary:
This PR adds an [ESRGAN](https://github.com/xinntao/Real-ESRGAN) upscaler example. It is a draft for now; see Notes.

Tested with [`RealESRGAN_x4plus`](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth). I did not find safetensors versions of these models, but the format is supported.

Compile:
```
python compile.py --model-path "RealESRGAN_x4plus.safetensors" --include-constants True
```

Demo:
```
python demo.py --module-path ./tmp/ESRGANModel/test.so --input-image-path "nvidiafu.png" --output-image-path "nvidiafu_upscaled.png"
```

Input image: (image not included)

Upscaled image: (image not included)

## Notes
* This depends on #846.
* Demo pipeline is basic: no batch support, no alpha channel support.
* Dynamic batch is untested but should work.
* ESRGAN does support alpha channel; it is just not implemented in the demo pipeline.
* Dynamic shape (height, width) is supported.
* Correct scale must be given to `demo.py`.
* Model arch can be configured using the compile.py arguments `--num-in-ch`, `--num-out-ch`, `--num-feat`, `--num-block`, `--num-grow-ch`, `--scale`.
  * `--num-block` and `--scale` are most likely to be changed.
  * e.g. `--num-block 6` for `RealESRGAN_x4plus_anime_6B` or `--scale 2` for `RealESRGAN_x2plus`.
* Scale != 4 models, e.g. `RealESRGAN_x2plus`, support static shape only.
* Needs documentation.

After the dependency is merged and documentation is added, I will mark this PR as ready. The demo pipeline could be improved, but it is just a demo; a full-featured implementation is left to the developer integrating this module.

Pull Request resolved: #847

Reviewed By: sgrigory

Differential Revision: D48425425

Pulled By: aakhundov

fbshipit-source-id: 38602581fd8991540240d8c58d274af905803876
1 parent 179bfb2, commit e3d3214

Showing 4 changed files with 633 additions and 0 deletions.
@@ -0,0 +1,137 @@
## ESRGAN Example

In this example, we show how to build fast AIT modules for ESRGAN models, and benchmark/run them.

### Build Dependencies

First, clone, build, and install AITemplate [per the README instructions](https://github.com/facebookincubator/AITemplate#clone-the-code).

This AIT ESRGAN example depends on `torch`, `click` and optionally `safetensors`. You can install them using `pip`.

### Download the ESRGAN model

We have tested the following ESRGAN models.

[RealESRGAN_x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth)

Model architecture:
```
num_in_ch: 3,
num_out_ch: 3,
num_feat: 64,
num_block: 23,
num_grow_ch: 32,
scale: 4,
```


[RealESRGAN_x4plus_anime_6B](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth)

Model architecture:
```
num_in_ch: 3,
num_out_ch: 3,
num_feat: 64,
num_block: 6,
num_grow_ch: 32,
scale: 4,
```


[RealESRGAN_x2plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth)

Model architecture:
```
num_in_ch: 3,
num_out_ch: 3,
num_feat: 64,
num_block: 23,
num_grow_ch: 32,
scale: 2,
```
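
The three architectures above differ from the `compile.py` defaults only in `num_block` and `scale`. As a rough sketch (the `PRESETS` table and the `compile_args` helper are hypothetical and not part of this example's code), the values can be turned into the matching `compile.py` flags like this:

```python
# Defaults of compile.py, which correspond to RealESRGAN_x4plus.
DEFAULTS = {
    "num_in_ch": 3,
    "num_out_ch": 3,
    "num_feat": 64,
    "num_block": 23,
    "num_grow_ch": 32,
    "scale": 4,
}

# Hypothetical preset table, copied from the architectures listed above.
PRESETS = {
    "RealESRGAN_x4plus": {"num_block": 23, "scale": 4},
    "RealESRGAN_x4plus_anime_6B": {"num_block": 6, "scale": 4},
    "RealESRGAN_x2plus": {"num_block": 23, "scale": 2},
}


def compile_args(model_name):
    """Return compile.py flags for every value that differs from the defaults."""
    arch = {**DEFAULTS, **PRESETS[model_name]}
    args = []
    for key, value in arch.items():
        if value != DEFAULTS[key]:
            # num_block -> --num-block, matching the option names of compile.py.
            args += ["--" + key.replace("_", "-"), str(value)]
    return args


print(compile_args("RealESRGAN_x4plus_anime_6B"))
print(compile_args("RealESRGAN_x2plus"))
```

A model that is not in the table would need its architecture values looked up separately before compiling.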

A database of ESRGAN models can be found [here](https://upscale.wiki/wiki/Model_Database).

Safetensors versions are supported.

### Build AIT modules for ESRGAN

Build the AIT modules by running `compile.py`.

```
Usage: compile.py [OPTIONS]
Options:
  --model-path TEXT               model path. supports torch or safetensors
  --width <INTEGER INTEGER>...    Minimum and maximum width
  --height <INTEGER INTEGER>...   Minimum and maximum height
  --batch-size <INTEGER INTEGER>...
                                  Minimum and maximum batch size
  --include-constants BOOLEAN     include constants (model weights) with
                                  compiled model
  --num-in-ch INTEGER             Number of in channels
  --num-out-ch INTEGER            Number of out channels
  --num-feat INTEGER              Number of intermediate features
  --num-block INTEGER             Number of RRDB layers
  --num-grow-ch INTEGER           Number of channels for each growth
  --scale INTEGER                 Scale
  --use-fp16-acc BOOLEAN          use fp16 accumulation
  --convert-conv-to-gemm BOOLEAN  convert 1x1 conv to gemm
  --work-dir TEXT                 Work directory
  --model-name TEXT               Model name
  --help                          Show this message and exit.
```

Use the `--num-in-ch`, `--num-out-ch`, `--num-feat`, `--num-block`, `--num-grow-ch` and `--scale` options to set the ESRGAN model architecture. The default values are for the `RealESRGAN_x4plus` architecture.

`--width` and `--height` each take a minimum and a maximum value, and the compiled module supports every resolution in that range. With the 2x model architecture, however, only static shapes are supported, so the maximum value of each dimension is used. Defaults are `64` and `1024`.
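
To make the range rule concrete, here is a minimal sketch of which input sizes a compiled module would accept under the behaviour described above. The helpers `effective_range` and `fits` are hypothetical and only restate this README; they do not reflect how AITemplate resolves shapes internally.

```python
def effective_range(min_max, scale):
    """Return the (min, max) size accepted for one dimension.

    Assumption taken from the text above: for scale != 4 only a static
    shape is supported, so the maximum value is used for both bounds.
    """
    lo, hi = min_max
    if lo > hi:
        raise ValueError(f"minimum {lo} is larger than maximum {hi}")
    if scale != 4:
        return (hi, hi)
    return (lo, hi)


def fits(size, min_max, scale):
    """Check whether an input size falls inside the accepted range."""
    lo, hi = effective_range(min_max, scale)
    return lo <= size <= hi


# With the defaults (64, 1024), a 4x module accepts a 300-pixel-wide input,
# while a 2x module compiled with the same options accepts only 1024.
print(fits(300, (64, 1024), scale=4))
print(fits(300, (64, 1024), scale=2))
print(fits(1024, (64, 1024), scale=2))
```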

Dynamic `--batch-size` is supported for the 4x model architecture; provide minimum and maximum values. Default is `1 1`.

Use `--include-constants False` to compile the module without model weights.

AIT modules are compatible with all ESRGAN models that share the same model architecture. This can simplify deployment: compile a module without model weights, then apply weights at runtime by passing AIT mapped weights (see `map_rrdb` in `./modeling/rrdbnet.py`) to the module's `set_many_constants_with_tensors`.
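
The runtime-weights workflow described above could look roughly like the following sketch. `unwrap_state_dict` and `load_module_with_weights` are hypothetical helper names. The sketch assumes that `map_rrdb` returns a dict of tensors that are already on the GPU and in the dtype the module expects, and that the module was compiled with `--include-constants False`; neither assumption is verified here.

```python
def unwrap_state_dict(checkpoint):
    """Return the weights dict, unwrapping the optional top-level key.

    Mirrors compile.py: "params_ema" is preferred, then "params",
    otherwise the checkpoint itself is taken to be the weights.
    """
    for key in ("params_ema", "params"):
        if key in checkpoint:
            return checkpoint[key]
    return checkpoint


def load_module_with_weights(module_path, weights_path, scale=4):
    """Load a compiled AIT module and apply weights at runtime (sketch)."""
    # Imported lazily so unwrap_state_dict can be used without these packages.
    import torch
    from aitemplate.compiler import Model
    from modeling.rrdbnet import map_rrdb

    state_dict = unwrap_state_dict(torch.load(weights_path, map_location="cpu"))
    # Assumption: map_rrdb yields {constant_name: tensor} ready for the module.
    constants = map_rrdb(state_dict, scale=scale)

    module = Model(module_path)
    module.set_many_constants_with_tensors(constants)
    return module
```

Because the architecture is fixed at compile time, the same compiled module can then be reused with any weights file of that architecture.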

In our tests, an ESRGAN module compiled with weights is approximately 38MB, and approximately 6.5MB without.

Examples:

```
python compile.py --model-path "RealESRGAN_x4plus.safetensors"
```

```
python compile.py --model-path "RealESRGAN_x4plus_anime_6B.pth" --num-block 6 --model-name RealESRGAN_x4plus_anime_6B
```

```
python compile.py --model-path "RealESRGAN_x2plus.pth" --scale 2 --model-name RealESRGAN_x2plus --width 512 512 --height 512 512
```


#### Multi-GPU profiling
AIT needs to do profiling to select the best algorithms for CUTLASS and CK.
To enable multiple GPUs for profiling, use the environment variable `CUDA_VISIBLE_DEVICES` on NVIDIA platforms and `HIP_VISIBLE_DEVICES` on AMD platforms.
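
As an illustration, the environment for a compile run can be prepared from Python before launching `compile.py` as a subprocess. The `profiling_env` helper below is hypothetical; only the two environment variable names come from the text above.

```python
import os


def profiling_env(gpu_ids, platform="cuda", base_env=None):
    """Return an environment dict that exposes the given GPUs for profiling."""
    variable = {
        "cuda": "CUDA_VISIBLE_DEVICES",  # NVIDIA
        "rocm": "HIP_VISIBLE_DEVICES",   # AMD
    }[platform]
    env = dict(os.environ if base_env is None else base_env)
    env[variable] = ",".join(str(gpu_id) for gpu_id in gpu_ids)
    return env


# Example use (not executed here):
#   subprocess.run([sys.executable, "compile.py"], env=profiling_env([0, 1]))
print(profiling_env([0, 1], base_env={}))
```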


### Run Models

`demo.py` provides example usage of ESRGAN modules.

```
Usage: demo.py [OPTIONS]
Options:
  --module-path TEXT        the AIT module path
  --input-image-path TEXT   path to input image
  --output-image-path TEXT  path to output image
  --scale INTEGER           Scale of ESRGAN model
  --help                    Show this message and exit.
```

`--scale` must match the scale of the model architecture.
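
One reason the scale matters is that the caller has to size the output buffer from it. The following sketch shows the arithmetic only; `upscaled_shape` is a hypothetical helper, and the `[batch, height, width, channels]` output layout is an assumption based on the input layout used in `compile.py`.

```python
def upscaled_shape(batch, height, width, channels=3, scale=4):
    """Return the expected output shape for an input of the given size."""
    return (batch, height * scale, width * scale, channels)


# A 4x model and a 2x model produce differently sized outputs for the same
# input, so passing the wrong --scale would allocate the wrong buffer.
print(upscaled_shape(1, 512, 512, scale=4))
print(upscaled_shape(1, 512, 512, scale=2))
```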

Limitations:
* Demo does not support multiple images/batch size.
* Demo does not support images with alpha channel.
@@ -0,0 +1,176 @@
#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import logging

import click
import safetensors.torch
import torch

from aitemplate.compiler import compile_model
from aitemplate.frontend import IntVar, Tensor
from aitemplate.testing import detect_target

from modeling.rrdbnet import map_rrdb, mark_output, RRDBNet


@click.command()
@click.option(
    "--model-path",
    default="RealESRGAN_x4plus.safetensors",
    help="model path. supports torch or safetensors",
)
@click.option(
    "--width",
    default=(64, 1024),
    type=(int, int),
    nargs=2,
    help="Minimum and maximum width",
)
@click.option(
    "--height",
    default=(64, 1024),
    type=(int, int),
    nargs=2,
    help="Minimum and maximum height",
)
@click.option(
    "--batch-size",
    default=(1, 1),
    type=(int, int),
    nargs=2,
    help="Minimum and maximum batch size",
)
@click.option(
    "--include-constants",
    default=True,
    type=bool,
    help="include constants (model weights) with compiled model",
)
@click.option(
    "--num-in-ch",
    default=3,
    type=int,
    help="Number of in channels",
)
@click.option(
    "--num-out-ch",
    default=3,
    type=int,
    help="Number of out channels",
)
@click.option(
    "--num-feat",
    default=64,
    type=int,
    help="Number of intermediate features",
)
@click.option(
    "--num-block",
    default=23,
    type=int,
    help="Number of RRDB layers",
)
@click.option(
    "--num-grow-ch",
    default=32,
    type=int,
    help="Number of channels for each growth",
)
@click.option(
    "--scale",
    default=4,
    type=int,
    help="Scale",
)
@click.option("--use-fp16-acc", default=True, help="use fp16 accumulation")
@click.option("--convert-conv-to-gemm", default=True, help="convert 1x1 conv to gemm")
@click.option("--work-dir", default="./tmp", help="Work directory")
@click.option("--model-name", default="ESRGANModel", help="Model name")
def compile_esrgan(
    model_path,
    width,
    height,
    batch_size,
    include_constants,
    num_in_ch,
    num_out_ch,
    num_feat,
    num_block,
    num_grow_ch,
    scale,
    use_fp16_acc=True,
    convert_conv_to_gemm=True,
    work_dir="./tmp",
    model_name="ESRGANModel",
):
    if scale != 4:
        print(
            "Scale != 4 supports static shape only. Maximum value of batch_size, height and width will be used."
        )

    logging.getLogger().setLevel(logging.INFO)
    torch.manual_seed(4896)

    # 1x1 conv-to-GEMM conversion is turned off on ROCm targets.
    if detect_target().name() == "rocm":
        convert_conv_to_gemm = False

    # Load the checkpoint from either a safetensors or a torch file.
    if model_path.endswith(".safetensors"):
        pt_model = safetensors.torch.load_file(model_path)
    else:
        pt_model = torch.load(model_path)

    # Some checkpoints nest the weights under "params_ema" or "params".
    if "params_ema" in pt_model:
        pt_model = pt_model["params_ema"]
    elif "params" in pt_model:
        pt_model = pt_model["params"]

    rrdbnet = RRDBNet(
        num_in_ch=num_in_ch,
        num_out_ch=num_out_ch,
        scale=scale,
        num_feat=num_feat,
        num_block=num_block,
        num_grow_ch=num_grow_ch,
    )
    rrdbnet.name_parameter_tensor()

    # Map the PyTorch weights to AIT constant names.
    constants = map_rrdb(pt_model, scale=scale)

    # Dynamic dimensions: each IntVar spans the given (minimum, maximum) range.
    batch_size = IntVar(values=list(batch_size), name="batch_size")
    channels = num_in_ch
    height = IntVar(values=list(height), name="height")
    width = IntVar(values=list(width), name="width")

    # The input tensor uses a [batch, height, width, channels] layout.
    image = Tensor(
        shape=[batch_size, height, width, channels], name="input_pixels", is_input=True
    )

    Y = rrdbnet(image)
    Y = mark_output(Y, "upscaled_pixels")

    target = detect_target(
        use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm
    )
    compile_model(
        Y,
        target,
        work_dir,
        model_name,
        constants=constants if include_constants else None,
    )


if __name__ == "__main__":
    compile_esrgan()