Skip to content

Commit

Permalink
fix: onnxruntime-gpu
Browse files Browse the repository at this point in the history
修复了在onnxruntime gpu版本下,运行会报错的问题
  • Loading branch information
Zeyi-Lin committed Sep 8, 2024
1 parent c58e826 commit f6bb43e
Showing 1 changed file with 44 additions and 4 deletions.
48 changes: 44 additions & 4 deletions hivision/creator/human_matting.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,35 @@
"rmbg-1.4": os.path.join(os.path.dirname(__file__), "weights", "rmbg-1.4.onnx"),
}

ONNX_DEVICE = (
"CUDAExecutionProvider"
if onnxruntime.get_device() == "GPU"
else "CPUExecutionProvider"
)


def load_onnx_model(checkpoint_path):
providers = (
["CUDAExecutionProvider", "CPUExecutionProvider"]
if ONNX_DEVICE == "CUDAExecutionProvider"
else ["CPUExecutionProvider"]
)

try:
sess = onnxruntime.InferenceSession(checkpoint_path, providers=providers)
except Exception as e:
if ONNX_DEVICE == "CUDAExecutionProvider":
print(f"Failed to load model with CUDAExecutionProvider: {e}")
print("Falling back to CPUExecutionProvider")
# 尝试使用CPU加载模型
sess = onnxruntime.InferenceSession(
checkpoint_path, providers=["CPUExecutionProvider"]
)
else:
raise e # 如果是CPU执行失败,重新抛出异常

return sess


def extract_human(ctx: Context):
"""
Expand Down Expand Up @@ -140,9 +169,11 @@ def read_modnet_image(input_image, ref_size=512):


def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
# global sess
# if sess is None:
sess = onnxruntime.InferenceSession(checkpoint_path)
if not os.path.exists(checkpoint_path):
print(f"Checkpoint file not found: {checkpoint_path}")
return None

sess = load_onnx_model(checkpoint_path)

input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name
Expand All @@ -161,13 +192,17 @@ def get_modnet_matting(input_image, checkpoint_path, ref_size=512):


def get_rmbg_matting(input_image: np.ndarray, checkpoint_path, ref_size=1024):
if not os.path.exists(checkpoint_path):
print(f"Checkpoint file not found: {checkpoint_path}")
return None

def resize_rmbg_image(image):
image = image.convert("RGB")
model_input_size = (ref_size, ref_size)
image = image.resize(model_input_size, Image.BILINEAR)
return image

sess = onnxruntime.InferenceSession(checkpoint_path)
sess = load_onnx_model(checkpoint_path)

orig_image = Image.fromarray(input_image)
image = resize_rmbg_image(orig_image)
Expand Down Expand Up @@ -203,13 +238,18 @@ def resize_rmbg_image(image):


def get_mnn_modnet_matting(input_image, checkpoint_path, ref_size=512):
if not os.path.exists(checkpoint_path):
print(f"Checkpoint file not found: {checkpoint_path}")
return None

try:
import MNN.expr as expr
import MNN.nn as nn
except ImportError as e:
raise ImportError(
"The MNN module is not installed or there was an import error. Please ensure that the MNN library is installed by using the command 'pip install mnn'."
) from e

config = {}
config["precision"] = "low" # 当硬件支持(armv8.2)时使用fp16推理
config["backend"] = 0 # CPU
Expand Down

0 comments on commit f6bb43e

Please sign in to comment.