diff --git a/hivision/creator/human_matting.py b/hivision/creator/human_matting.py index dfa71f9f..fdad0ba1 100644 --- a/hivision/creator/human_matting.py +++ b/hivision/creator/human_matting.py @@ -25,6 +25,11 @@ "weights", "modnet_photographic_portrait_matting.onnx", ), + "mnn_hivision_modnet": os.path.join( + os.path.dirname(__file__), + "weights", + "mnn_hivision_modnet.mnn", + ) } @@ -40,6 +45,32 @@ def extract_human(ctx: Context): ctx.matting_image = ctx.processing_image.copy() +def get_mnn_modnet_matting(input_image, checkpoint_path, ref_size=512): + try: + import MNN.expr as expr + import MNN.nn as nn + except ImportError as e: + raise ImportError("MNN模块未安装或导入错误。请确保已安装MNN库,使用命令 'pip install mnn' 安装。") from e + config = {} + config['precision'] = 'low' # 当硬件支持(armv8.2)时使用fp16推理 + config['backend'] = 0 # CPU + config['numThread'] = 4 # 线程数 + im, width, length = read_modnet_image(input_image, ref_size=512) + rt = nn.create_runtime_manager((config,)) + net = nn.load_module_from_file(checkpoint_path, ['input1'], ['output1'], runtime_manager=rt) + input_var = expr.convert(im, expr.NCHW) + output_var = net.forward(input_var) + matte = expr.convert(output_var, expr.NCHW) + matte = matte.read()#var转换为np + matte = (matte * 255).astype("uint8") + matte = np.squeeze(matte) + mask = cv2.resize(matte, (width, length), interpolation=cv2.INTER_AREA) + b, g, r = cv2.split(np.uint8(input_image)) + + output_image = cv2.merge((b, g, r, mask)) + + return output_image + def extract_human_modnet_photographic_portrait_matting(ctx: Context): """ 人像抠图 @@ -53,6 +84,11 @@ def extract_human_modnet_photographic_portrait_matting(ctx: Context): ctx.processing_image = hollow_out_fix(matting_image) ctx.matting_image = ctx.processing_image.copy() +def extract_human_mnn_modnet(ctx: Context): + matting_image = get_mnn_modnet_matting(ctx.processing_image, WEIGHTS["mnn_hivision_modnet"]) + ctx.processing_image = hollow_out_fix(matting_image) + ctx.matting_image = ctx.processing_image.copy() + def hollow_out_fix(src: np.ndarray) -> np.ndarray: """ diff --git a/inference.py b/inference.py index e15764a1..56569286 100644 --- a/inference.py +++ b/inference.py @@ -12,6 +12,7 @@ from hivision.creator.human_matting import ( extract_human_modnet_photographic_portrait_matting, extract_human, + extract_human_mnn_modnet, ) parser = argparse.ArgumentParser(description="HivisionIDPhotos 证件照制作推理程序。") @@ -24,7 +25,7 @@ "add_background", "generate_layout_photos", ] -MATTING_MODEL = ["hivision_modnet", "modnet_photographic_portrait_matting"] +MATTING_MODEL = ["hivision_modnet", "modnet_photographic_portrait_matting", "mnn_hivision_modnet"] RENDER = [0, 1, 2] parser.add_argument( @@ -64,6 +65,8 @@ creator.matting_handler = extract_human elif args.matting_model == "modnet_photographic_portrait_matting": creator.matting_handler = extract_human_modnet_photographic_portrait_matting +elif args.matting_model == "mnn_hivision_modnet": + creator.matting_handler = extract_human_mnn_modnet root_dir = os.path.dirname(os.path.abspath(__file__)) input_image = cv2.imread(args.input_image_dir, cv2.IMREAD_UNCHANGED) diff --git a/requirements.txt b/requirements.txt index 8d1f2808..c469676e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,4 @@ opencv-python>=4.8.1.78 onnxruntime>=1.15.0 numpy<=1.26.4 requests -mtcnn-runtime \ No newline at end of file +mtcnn-runtime