Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add controlnet module #79

Merged
merged 4 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions bmf/demo/controlnet/ReadMe.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# BMF ControlNet Demo

This demo shows how to use ControlNet+StableDiffusion to generate images from text prompts in BMF. We use a performance-optimized ControlNet [implementation](https://github.com/NVIDIA/trt-samples-for-hackathon-cn/tree/master/Hackathon2023/controlnet). This implementation accelerates the canny2image app in the official ControlNet repo.

You need to compile or install bmf before running the demo. Please refer to the [document](https://babitmf.github.io/docs/bmf/getting_started_yourself/install/) on how to build or install bmf.

### Generate TensorRT Engine

First, we need to put the ControlNet code in the demo directory. The repo below contains many TensorRT samples; the ControlNet implementation we need is located in `trt-samples-for-hackathon-cn/Hackathon2023/controlnet`.
```Bash
git clone https://github.com/NVIDIA/trt-samples-for-hackathon-cn.git
# copy the controlnet implementation to the demo path for simplicity
cp -r trt-samples-for-hackathon-cn/Hackathon2023/controlnet bmf/demo/controlnet
```

Download the state dict from HuggingFace and generate the TensorRT engine. You need to change the state dict path in `controlnet/export_onnx.py:19` to where you put the file. Then run `preprocess.sh` to build the TensorRT engine.
```Bash
cd bmf/demo/controlnet/controlnet/models
wget https://huggingface.co/lllyasviel/ControlNet/resolve/main/models/control_sd15_canny.pth
# Change the path to './models/control_sd15_canny.pth' in controlnet/export_onnx.py:19
cd .. # go back to the controlnet directory
bash preprocess.sh
```

Once the script runs successfully, several `.trt` files will be generated; these are the TensorRT engines. Move the generated TensorRT engines to the demo directory and run the ControlNet pipeline using the `test_controlnet.py` script.
```Bash
mv *.trt path/to/the/demo
cd path/to/the/demo
python test_controlnet.py
```
The pipeline will generate a new image based on the input image and prompt.
84 changes: 84 additions & 0 deletions bmf/demo/controlnet/controlnet_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import sys
import random
from typing import List, Optional
import numpy as np
import pdb

from bmf import *
import bmf.hml.hmp as mp
sys.path.append('./controlnet')
from canny2image_TRT import hackathon

class controlnet_module(Module):
    """BMF module that runs ControlNet image generation on incoming frames.

    The module has two input streams and one output stream:

    * input 0 carries ``VideoFrame`` packets (the conditioning images),
    * input 1 carries ``dict`` packets holding the prompt; the keys
      ``'prompt'``, ``'a_prompt'`` and ``'n_prompt'`` are read from it,
    * output 0 receives one generated ``VideoFrame`` packet per input frame.

    Frames are buffered until a prompt is available; the most recently
    received prompt is used for every frame processed afterwards.
    """

    def __init__(self, node, option=None):
        """Create the module and initialize the inference engine.

        Args:
            node: Node identifier handed in by the framework; used for logging.
            option: Optional dict of options. If it contains ``'path'``, the
                value is stored in ``self.prompt_path``.
        """
        self.node_ = node
        self.hk = hackathon()
        self.hk.initialize()
        self.prompt_path = './prompt.txt'
        # EOF flags: index 0 is the prompt stream, index 1 is the image stream.
        self.eof_received_ = [False, False]
        self.prompt_ = None
        self.frame_list_ = []
        # `option` defaults to None, so guard before looking up keys.
        if option and 'path' in option:
            self.prompt_path = option['path']

    def _generate_packet(self, in_frame):
        """Run inference for one frame and wrap the result in a Packet."""
        prompt = self.prompt_
        # NOTE(review): the positional arguments after the three prompt strings
        # presumably follow ControlNet's canny2image `process` signature
        # (num_samples, image_resolution, ddim_steps, guess_mode, strength,
        # scale, seed, eta, low_threshold, high_threshold) -- confirm against
        # canny2image_TRT.hackathon.process.
        gen_img = self.hk.process(in_frame.cpu().frame().data()[0].numpy(),
                                  prompt['prompt'], prompt['a_prompt'], prompt['n_prompt'],
                                  1,
                                  256,
                                  20,
                                  False,
                                  1,
                                  9,
                                  2946901,
                                  0.0,
                                  100,
                                  200)

        # The output is tagged as RGB24 while reusing the color space and
        # range reported by the input frame.
        in_pix_info = in_frame.frame().pix_info()
        rgbinfo = mp.PixelInfo(mp.PixelFormat.kPF_RGB24,
                               in_pix_info.space,
                               in_pix_info.range)
        out_f = mp.Frame(mp.from_numpy(gen_img[0]), rgbinfo)
        out_vf = VideoFrame(out_f)
        out_vf.pts = in_frame.pts
        out_vf.time_base = in_frame.time_base
        out_pkt = Packet(out_vf)
        out_pkt.timestamp = out_vf.pts
        return out_pkt

    def process(self, task):
        """Consume pending prompt/frame packets and emit generated frames.

        Args:
            task: The task providing the input queues (index 0: frames,
                index 1: prompts) and the output queue (index 0).

        Returns:
            ``ProcessResult.OK`` in all cases. Once both inputs have reached
            EOF and no buffered frame remains, an EOF packet is emitted and
            the task timestamp is set to ``Timestamp.DONE``.
        """
        img_queue = task.get_inputs()[0]
        pmt_queue = task.get_inputs()[1]
        output_queue = task.get_outputs()[0]

        while not pmt_queue.empty():
            pmt_pkt = pmt_queue.get()
            if pmt_pkt.timestamp == Timestamp.EOF:
                self.eof_received_[0] = True
            else:
                self.prompt_ = pmt_pkt.get(dict)

        while not img_queue.empty():
            in_pkt = img_queue.get()
            if in_pkt.timestamp == Timestamp.EOF:
                self.eof_received_[1] = True
            else:
                self.frame_list_.append(in_pkt.get(VideoFrame))

        # Use the cached prompt rather than a local bound in the loop above:
        # the prompt may have arrived in an earlier process() call than the
        # frames, in which case no local would exist here.
        while self.prompt_ and self.frame_list_:
            in_frame = self.frame_list_.pop(0)
            output_queue.put(self._generate_packet(in_frame))

        # NOTE(review): if both streams end without any prompt having been
        # received, buffered frames are never drained and EOF is never sent.
        if self.eof_received_[0] and self.eof_received_[1] and not self.frame_list_:
            output_queue.put(Packet.generate_eof_packet())
            Log.log_node(LogLevel.DEBUG, self.node_, 'output text stream', 'done')
            task.set_timestamp(Timestamp.DONE)

        return ProcessResult.OK

def register_inpaint_module_info(info):
    """Set the description string on the given module-info object."""
    setattr(info, "module_description", "ControlNet inference module")
3 changes: 3 additions & 0 deletions bmf/demo/controlnet/prompt.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
prompt: a bird
a_prompt: best quality, extremely detailed
n_prompt: longbody, lowres, bad anatomy, bad hands, missing fingers
60 changes: 60 additions & 0 deletions bmf/demo/controlnet/test_controlnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import sys

sys.path.append("../../")
import bmf

sys.path.pop()

def test():
    """Build and run the two-input ControlNet demo graph.

    The decoded input image feeds both the ``text_module`` (which emits the
    prompt read from ``./prompt.txt``) and the ``controlnet_module``; the
    generated frame is encoded to ``./output.png``.
    """
    image_path = "./controlnet/test_imgs/bird.png"
    prompt_file = "./prompt.txt"
    result_path = "./output.png"

    graph = bmf.graph()

    # Dual inputs: the decoded stream drives the prompt reader, and both
    # streams are then handed to the ControlNet module.
    frames = graph.decode({'input_path': image_path})
    prompt_stream = frames.module('text_module', {'path': prompt_file})

    controlnet = bmf.module(streams=[frames, prompt_stream],
                            module_info='controlnet_module')
    encoder = controlnet.encode(None, {'output_path': result_path})
    encoder.run()

    # Alternative kept for reference -- driving the same modules in sync mode:
    #
    #   from bmf import bmf_sync, Packet
    #   decoder = bmf_sync.sync_module("c_ffmpeg_decoder", {"input_path": "./ControlNet/test_imgs/bird.png"}, [], [0])
    #   prompt = bmf_sync.sync_module('text_module', {'path': './prompt.txt'}, [], [1])
    #   controlnet = bmf_sync.sync_module('controlnet_module', {}, [0, 1], [0])
    #
    #   decoder.init()
    #   prompt.init()
    #   controlnet.init()
    #
    #   img, _ = bmf_sync.process(decoder, None)
    #   txt, _ = bmf_sync.process(prompt, None)
    #   gen_img, _ = bmf_sync.process(controlnet, {0: img[0], 1: txt[1]})


# Run the demo pipeline only when this file is executed as a script.
if __name__ == '__main__':
    test()
43 changes: 43 additions & 0 deletions bmf/demo/controlnet/text_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import sys
import random
from typing import List, Optional
import pdb

from bmf import *
import bmf.hml.hmp as mp

class text_module(Module):
    """BMF module that emits the prompt file's contents as a ``dict`` packet.

    For every non-EOF packet received on input 0, the prompt file is read and
    a packet holding the parsed key/value pairs is put on output 0. The input
    packet's payload itself is not inspected; it only triggers the emission.
    """

    def __init__(self, node, option=None):
        """Create the module.

        Args:
            node: Node identifier handed in by the framework; used for logging.
            option: Optional dict of options. If it contains ``'path'``, the
                value overrides the default prompt file ``'./prompt.txt'``.
        """
        self.node_ = node
        self.eof_received_ = False
        self.prompt_path = './prompt.txt'
        # `option` defaults to None, so guard before looking up keys.
        if option and 'path' in option:
            self.prompt_path = option['path']

    def _load_prompt(self):
        """Parse ``self.prompt_path`` into a dict of ``key: value`` strings.

        Each line is split on the first ``':'``. Keys and values are stripped
        of surrounding whitespace so that a line such as ``prompt: a bird``
        yields ``'a bird'`` rather than ``' a bird\\n'``. Lines without a
        colon or with an empty key (e.g. blank lines) are skipped.
        """
        prompt_dict = {}
        with open(self.prompt_path) as f:
            for line in f:
                key, sep, value = line.partition(":")
                key = key.strip()
                if not sep or not key:
                    continue
                prompt_dict[key] = value.strip()
        return prompt_dict

    def process(self, task):
        """Emit one prompt packet per input packet and forward EOF.

        Args:
            task: The task providing the input queue (index 0) and the output
                queue (index 0).

        Returns:
            ``ProcessResult.OK`` in all cases. On EOF, an EOF packet is
            emitted and the task timestamp is set to ``Timestamp.DONE``.
        """
        input_packets = task.get_inputs()[0]
        output_queue = task.get_outputs()[0]

        while not input_packets.empty():
            pkt = input_packets.get()
            if pkt.timestamp == Timestamp.EOF:
                output_queue.put(Packet.generate_eof_packet())
                Log.log_node(LogLevel.DEBUG, self.node_, 'output text stream', 'done')
                task.set_timestamp(Timestamp.DONE)
                return ProcessResult.OK

            out_pkt = Packet(self._load_prompt())
            out_pkt.timestamp = 0
            output_queue.put(out_pkt)

        return ProcessResult.OK

def register_inpaint_module_info(info):
    """Set the description string on the given module-info object."""
    setattr(info, "module_description", "Text file IO module")