diff --git a/README.md b/README.md index 87fa639c..d030c421 100644 --- a/README.md +++ b/README.md @@ -111,6 +111,15 @@ python app.py # 🚀 Python 推理 +核心参数: + +- `-i`: 输入图像路径 +- `-o`: 保存图像路径 +- `-t`: 推理类型,有idphoto、human_matting、add_background、generate_layout_photos可选 +- `--matting_model`: 人像抠图模型权重选择,可选`hivision_modnet`、`modnet_photographic_portrait_matting` + +更多参数可通过`python inference.py --help`查看 + ## 1. 证件照制作 输入 1 张照片,获得 1 张标准证件照和 1 张高清证件照的 4 通道透明 png @@ -119,15 +128,21 @@ python app.py python inference.py -i demo/images/test.jpg -o ./idphoto.png --height 413 --width 295 ``` -## 2. 增加底色 +## 2. 人像抠图 + +```python +python inference.py -i -t human_matting demo/images/test.jpg -o ./idphoto_matting.png --matting_model hivision_modnet +``` + +## 3. 透明图增加底色 输入 1 张 4 通道透明 png,获得 1 张增加了底色的图像) ```python -python inference.py -t add_background -i ./idphoto.png -o ./idhoto_ab.jpg -c 000000 -k 30 +python inference.py -t add_background -i ./idphoto.png -o ./idhoto_ab.jpg -c 4f83ce -k 30 -r 1 ``` -## 3. 得到六寸排版照 +## 4. 得到六寸排版照 输入 1 张 3 通道照片,获得 1 张六寸排版照 diff --git a/inference.py b/inference.py index 53c71969..89f52ad3 100644 --- a/inference.py +++ b/inference.py @@ -10,36 +10,66 @@ generate_layout_photo, generate_layout_image, ) +from hivision.creator.human_matting import ( + extract_human_modnet_photographic_portrait_matting, + extract_human, +) parser = argparse.ArgumentParser(description="HivisionIDPhotos 证件照制作推理程序。") creator = IDCreator() +INFERENCE_TYPE = [ + "idphoto", + "human_matting", + "add_background", + "generate_layout_photos", +] +MATTING_MODEL = ["hivision_modnet", "modnet_photographic_portrait_matting"] +RENDER = [0, 1, 2] + parser.add_argument( "-t", - "--type", - help="请求 API 的种类,有 idphoto、add_background 和 generate_layout_photos 可选", + help="请求 API 的种类", + choices=INFERENCE_TYPE, default="idphoto", ) -parser.add_argument("-i", "--input_image_dir", help="输入图像路径", required=True) -parser.add_argument("-o", "--output_image_dir", help="保存图像路径", required=True) +parser.add_argument("-i", help="输入图像路径", required=True) +parser.add_argument("-o", help="保存图像路径", required=True) parser.add_argument("--height", help="证件照尺寸-高", default=413) parser.add_argument("--width", help="证件照尺寸-宽", default=295) parser.add_argument("-c", "--color", help="证件照背景色", default="638cce") parser.add_argument( "-k", "--kb", help="输出照片的 KB 值,仅对换底和制作排版照生效", default=None ) +parser.add_argument( + "--matting_model", + help="抠图模型权重", + default="hivision_modnet", + choices=MATTING_MODEL, +) +parser.add_argument( + "-r", + "--render", + type=int, + help="底色合成的模式,有 0:纯色、1:上下渐变、2:中心渐变 可选", + choices=RENDER, + default=0, +) args = parser.parse_args() -root_dir = os.path.dirname(os.path.abspath(__file__)) +# ------------------- 人像抠图模型选择 ------------------- +if args.matting_model == "hivision_modnet": + creator.matting_handler = extract_human +elif args.matting_model == "modnet_photographic_portrait_matting": + creator.matting_handler = extract_human_modnet_photographic_portrait_matting +root_dir = os.path.dirname(os.path.abspath(__file__)) input_image = cv2.imread(args.input_image_dir, cv2.IMREAD_UNCHANGED) - # 如果模式是生成证件照 if args.type == "idphoto": - # 将字符串转为元组 size = (int(args.height), int(args.width)) try: @@ -55,15 +85,24 @@ new_file_name = file_name + "_hd" + file_extension cv2.imwrite(new_file_name, result.hd) +# 如果模式是人像抠图 +elif args.type == "human_matting": + result = creator(input_image, change_bg_only=True) + cv2.imwrite(args.output_image_dir, result.hd) + # 如果模式是添加背景 elif args.type == "add_background": + render_choice = ["pure_color", "updown_gradient", "center_gradient"] + # 将字符串转为元组 color = hex_to_rgb(args.color) # 将元祖的 0 和 2 号数字交换 color = (color[2], color[1], color[0]) - result_image = add_background(input_image, bgr=color) + result_image = add_background( + input_image, bgr=color, mode=render_choice[args.render] + ) result_image = result_image.astype(np.uint8) if args.kb: