diff --git a/README.md b/README.md
index da3e3391..0b519d4d 100644
--- a/README.md
+++ b/README.md
@@ -55,6 +55,7 @@
- 在线体验: [![SwanHub Demo](https://img.shields.io/static/v1?label=Demo&message=SwanHub%20Demo&color=blue)](https://swanhub.co/ZeYiLin/HivisionIDPhotos/demo)、[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/TheEeeeLin/HivisionIDPhotos)、[![][modelscope-shield]][modelscope-link]
+- 2024.09.22: Gradio Demo增加**野兽模式**,可设置内存加载策略
- 2024.09.18: Gradio Demo增加**分享模版照**功能、增加**美式证件照**背景选项
- 2024.09.17: Gradio Demo增加**自定义底色-HEX输入**功能 | **(社区贡献)C++版本** - [HivisionIDPhotos-cpp](https://github.com/zjkhahah/HivisionIDPhotos-cpp) 贡献 by [zjkhahah](https://github.com/zjkhahah)
- 2024.09.16: Gradio Demo增加**人脸旋转对齐**功能,自定义尺寸输入支持**毫米**单位
@@ -62,7 +63,6 @@
- 2024.09.12: Gradio Demo增加**美白**功能 | API接口增加**加水印**、**设置照片KB值大小**、**证件照裁切**
- 2024.09.11: Gradio Demo增加**透明图显示与下载**功能
- 2024.09.10: 增加新的**人脸检测模型** Retinaface-resnet50,以稍弱于mtcnn的速度换取更高的检测精度,推荐使用
-- 2024.09.09: 增加新的**抠图模型** [BiRefNet-v1-lite](https://github.com/ZhengPeng7/BiRefNet) | Gradio增加**高级参数设置**和**水印**选项卡
@@ -319,13 +319,15 @@ docker compose up -d
|--|--|--|--|
| FACE_PLUS_API_KEY | 可选 | 这是你在 Face++ 控制台申请的 API 密钥 | `7-fZStDJ····` |
| FACE_PLUS_API_SECRET | 可选 | Face++ API密钥对应的Secret | `VTee824E····` |
+| RUN_MODE | 可选 | 运行模式,可选值为`beast`(野兽模式)。野兽模式下人脸检测和抠图模型将不释放内存,从而获得更快的二次推理速度。建议内存16GB以上尝试。 | `beast` |
docker使用环境变量示例:
```bash
docker run -d -p 7860:7860 \
-e FACE_PLUS_API_KEY=7-fZStDJ···· \
-e FACE_PLUS_API_SECRET=VTee824E···· \
- linzeyi/hivision_idphotos
+ -e RUN_MODE=beast \
+ linzeyi/hivision_idphotos
```
diff --git a/app.py b/app.py
index 813cca26..ebb7db92 100644
--- a/app.py
+++ b/app.py
@@ -65,6 +65,11 @@
FACE_DETECT_MODELS_CHOICE,
LANGUAGE,
)
+
+ # 如果RUN_MODE是Beast,打印已开启野兽模式
+ if os.getenv("RUN_MODE") == "beast":
+ print("[Beast mode activated.] 已开启野兽模式。")
+
demo.launch(
server_name=args.host,
server_port=args.port,
diff --git a/hivision/creator/face_detector.py b/hivision/creator/face_detector.py
index 03088125..126896a9 100644
--- a/hivision/creator/face_detector.py
+++ b/hivision/creator/face_detector.py
@@ -213,3 +213,7 @@ def detect_face_retinaface(ctx: Context):
dx = right_eye[0] - left_eye[0]
roll_angle = np.degrees(np.arctan2(dy, dx))
ctx.face["roll_angle"] = roll_angle
+
+ # 如果RUN_MODE不是野兽模式,则释放模型
+ if os.getenv("RUN_MODE") == "beast":
+ RETINAFCE_SESS = None
\ No newline at end of file
diff --git a/hivision/creator/human_matting.py b/hivision/creator/human_matting.py
index 5027ef99..57b0d47d 100644
--- a/hivision/creator/human_matting.py
+++ b/hivision/creator/human_matting.py
@@ -201,6 +201,7 @@ def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
print(f"Checkpoint file not found: {checkpoint_path}")
return None
+ # 如果RUN_MODE不是野兽模式,则不加载模型
if HIVISION_MODNET_SESS is None:
HIVISION_MODNET_SESS = load_onnx_model(checkpoint_path, set_cpu=True)
@@ -216,6 +217,10 @@ def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
b, g, r = cv2.split(np.uint8(input_image))
output_image = cv2.merge((b, g, r, mask))
+
+ # 如果RUN_MODE不是野兽模式,则释放模型
+ if os.getenv("RUN_MODE") != "beast":
+ HIVISION_MODNET_SESS = None
return output_image
@@ -229,6 +234,7 @@ def get_modnet_matting_photographic_portrait_matting(
print(f"Checkpoint file not found: {checkpoint_path}")
return None
+ # 如果RUN_MODE不是野兽模式,则不加载模型
if MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS is None:
MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS = load_onnx_model(
checkpoint_path, set_cpu=True
@@ -248,6 +254,10 @@ def get_modnet_matting_photographic_portrait_matting(
b, g, r = cv2.split(np.uint8(input_image))
output_image = cv2.merge((b, g, r, mask))
+
+ # 如果RUN_MODE不是野兽模式,则释放模型
+ if os.getenv("RUN_MODE") != "beast":
+ MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS = None
return output_image
@@ -297,6 +307,10 @@ def resize_rmbg_image(image):
# Paste the mask on the original image
new_im = Image.new("RGBA", orig_image.size, (0, 0, 0, 0))
new_im.paste(orig_image, mask=pil_im)
+
+ # 如果RUN_MODE不是野兽模式,则释放模型
+ if os.getenv("RUN_MODE") != "beast":
+ RMBG_SESS = None
return np.array(new_im)
@@ -362,8 +376,9 @@ def transform_image(image):
# 记录加载onnx模型的开始时间
load_start_time = time()
+ # 如果RUN_MODE不是野兽模式,则不加载模型
if BIREFNET_V1_LITE_SESS is None:
- print("首次加载birefnet-v1-lite模型...")
+ # print("首次加载birefnet-v1-lite模型...")
if ONNX_DEVICE == "GPU":
print("onnxruntime-gpu已安装,尝试使用CUDA加载模型")
try:
@@ -405,5 +420,9 @@ def transform_image(image):
# Paste the mask on the original image
new_im = Image.new("RGBA", orig_image.size, (0, 0, 0, 0))
new_im.paste(orig_image, mask=pil_im)
+
+ # 如果RUN_MODE不是野兽模式,则释放模型
+ if os.getenv("RUN_MODE") != "beast":
+ BIREFNET_V1_LITE_SESS = None
return np.array(new_im)