Skip to content

Commit

Permalink
Add support for parallel processing with multiple GPUs
Browse files Browse the repository at this point in the history
  • Loading branch information
zackees committed Apr 25, 2024
1 parent 8f1330f commit 11f869e
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 7 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,7 @@ src/zcmds/cmds/common/losslesscut_bins/**
.aider*

!.aider.conf.yml
!.aiderignore
!.aiderignore

tests/test_data/rembg-nobackground.webm

67 changes: 61 additions & 6 deletions src/zcmds/cmds/common/removebackground.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import concurrent.futures
import os
import shutil
import subprocess
Expand Down Expand Up @@ -99,6 +100,7 @@ def video_remove_background(
fps_override: Optional[float] = None,
model: str = MODEL,
keep_files: bool = False,
exposed_gpus: Optional[list[int]] = None,
) -> None:
output_dir.mkdir(parents=True, exist_ok=True)
vidinfo: VidInfo = get_video_info(video_path)
Expand All @@ -109,12 +111,56 @@ def video_remove_background(
raise OSError("Error converting video to images")
print(f"Images saved to {output_dir}")

final_output_dir = output_dir / "video"
final_output_dir.mkdir(parents=True, exist_ok=True)
cmd = f'rembg p -a -ae 15 --post-process-mask -m {model} "{output_dir}" "{final_output_dir}"'
print(f"Running: {cmd}")
os.system(cmd)
print(f"Images with background removed saved to {final_output_dir}")
if exposed_gpus is None or len(exposed_gpus) == 1:
final_output_dir = output_dir / "video"
final_output_dir.mkdir(parents=True, exist_ok=True)
cmd = f'rembg p -a -ae 15 --post-process-mask -m {model} "{output_dir}" "{final_output_dir}"'
print(f"Running: {cmd}")
os.system(cmd)
print(f"Images with background removed saved to {final_output_dir}")
else:
# Split the images into subfolders for parallel processing
num_gpus = len(exposed_gpus)
img_files = list(output_dir.glob("*.png"))
img_files.sort()
chunk_size = (len(img_files) + num_gpus - 1) // num_gpus
img_chunks = [
img_files[i : i + chunk_size] for i in range(0, len(img_files), chunk_size)
]

def process_chunk(chunk, gpu_id):
chunk_dir = output_dir / str(gpu_id)
chunk_dir.mkdir(parents=True, exist_ok=True)
for img in chunk:
shutil.move(str(img), str(chunk_dir / img.name))

final_output_dir = chunk_dir / "video"
final_output_dir.mkdir(parents=True, exist_ok=True)
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
cmd = f'rembg p -a -ae 15 --post-process-mask -m {model} "{chunk_dir}" "{final_output_dir}"'
print(f"Running: {cmd}")
# os.system(cmd)
import subprocess

subprocess.run(cmd, shell=True, env=env, check=True)
print(f"Images with background removed saved to {final_output_dir}")
print()

with concurrent.futures.ThreadPoolExecutor(max_workers=num_gpus) as executor:
futures = [
executor.submit(process_chunk, chunk, gpu_id)
for gpu_id, chunk in enumerate(img_chunks)
]
concurrent.futures.wait(futures)

# Merge the processed images back into the main video directory
final_output_dir = output_dir / "video"
final_output_dir.mkdir(parents=True, exist_ok=True)
for gpu_id in range(num_gpus):
chunk_output_dir = output_dir / str(gpu_id) / "video"
for img in chunk_output_dir.glob("*.png"):
shutil.move(str(img), str(final_output_dir / img.name))

fps: float = fps_override if fps_override else vidinfo.fps
out_vid_path = Path(str(video_path.with_suffix("")) + f"-nobg-{model}.webm")
Expand Down Expand Up @@ -198,6 +244,12 @@ def parse_args() -> argparse.Namespace:
choices=MODEL_CHOICES,
help=f"Model to use (default: {MODEL}, choices: {MODEL_CHOICES})",
)
parser.add_argument(
"--gpu-count",
type=int,
default=1,
help="Number of GPUs to use for parallel processing (default: 1)",
)

return parser.parse_args()

Expand All @@ -220,6 +272,7 @@ def cli() -> int:
fps_override=args.fps,
keep_files=args.keep_files,
model=args.model,
exposed_gpus=[int(i) for i in range(args.gpu_count)],
)
return 0
cmd = f'rembg -a -ae 15 --post-process-mask -m {args.model} i "{args.file}"'
Expand Down Expand Up @@ -257,6 +310,8 @@ def unit_test() -> None:
_cd_to_project_root()
test_mp4 = test_data()
sys.argv.append(test_mp4)
sys.argv.append("--gpu-count")
sys.argv.append("2")
# u2net_human_seg
# sys.argv.append("--model")
# sys.argv.append("isnet-general-use")
Expand Down
Binary file removed tests/test_data/rembg-nobackground.webm
Binary file not shown.

0 comments on commit 11f869e

Please sign in to comment.