Skip to content

Commit

Permalink
fix: modnet sess
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeyi-Lin committed Sep 9, 2024
1 parent 310afe0 commit 1e33240
Showing 1 changed file with 33 additions and 1 deletion.
34 changes: 33 additions & 1 deletion hivision/creator/human_matting.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def extract_human_modnet_photographic_portrait_matting(ctx: Context):
:param ctx: 上下文
"""
# 抠图
matting_image = get_modnet_matting(
matting_image = get_modnet_matting_photographic_portrait_matting(
ctx.processing_image, WEIGHTS["modnet_photographic_portrait_matting"]
)
# 修复抠图
Expand Down Expand Up @@ -221,6 +221,38 @@ def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
return output_image


def get_modnet_matting_photographic_portrait_matting(
input_image, checkpoint_path, ref_size=512
):
global MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS

if not os.path.exists(checkpoint_path):
print(f"Checkpoint file not found: {checkpoint_path}")
return None

if MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS is None:
MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS = load_onnx_model(
checkpoint_path, set_cpu=True
)

input_name = MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS.get_inputs()[0].name
output_name = MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS.get_outputs()[0].name

im, width, length = read_modnet_image(input_image=input_image, ref_size=ref_size)

matte = MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS.run(
[output_name], {input_name: im}
)
matte = (matte[0] * 255).astype("uint8")
matte = np.squeeze(matte)
mask = cv2.resize(matte, (width, length), interpolation=cv2.INTER_AREA)
b, g, r = cv2.split(np.uint8(input_image))

output_image = cv2.merge((b, g, r, mask))

return output_image


def get_rmbg_matting(input_image: np.ndarray, checkpoint_path, ref_size=1024):
global RMBG_SESS

Expand Down

0 comments on commit 1e33240

Please sign in to comment.