From 1e3324028eb629c62582d8622f821c92076d0954 Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Mon, 9 Sep 2024 13:10:33 +0800 Subject: [PATCH] fix: modnet sess --- hivision/creator/human_matting.py | 34 ++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/hivision/creator/human_matting.py b/hivision/creator/human_matting.py index 626abed6..f746fb35 100644 --- a/hivision/creator/human_matting.py +++ b/hivision/creator/human_matting.py @@ -95,7 +95,7 @@ def extract_human_modnet_photographic_portrait_matting(ctx: Context): :param ctx: 上下文 """ # 抠图 - matting_image = get_modnet_matting( + matting_image = get_modnet_matting_photographic_portrait_matting( ctx.processing_image, WEIGHTS["modnet_photographic_portrait_matting"] ) # 修复抠图 @@ -221,6 +221,38 @@ def get_modnet_matting(input_image, checkpoint_path, ref_size=512): return output_image +def get_modnet_matting_photographic_portrait_matting( + input_image, checkpoint_path, ref_size=512 +): + global MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS + + if not os.path.exists(checkpoint_path): + print(f"Checkpoint file not found: {checkpoint_path}") + return None + + if MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS is None: + MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS = load_onnx_model( + checkpoint_path, set_cpu=True + ) + + input_name = MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS.get_inputs()[0].name + output_name = MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS.get_outputs()[0].name + + im, width, length = read_modnet_image(input_image=input_image, ref_size=ref_size) + + matte = MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS.run( + [output_name], {input_name: im} + ) + matte = (matte[0] * 255).astype("uint8") + matte = np.squeeze(matte) + mask = cv2.resize(matte, (width, length), interpolation=cv2.INTER_AREA) + b, g, r = cv2.split(np.uint8(input_image)) + + output_image = cv2.merge((b, g, r, mask)) + + return output_image + + def get_rmbg_matting(input_image: np.ndarray, checkpoint_path, ref_size=1024): global RMBG_SESS