From 8fbe6bcb31b6dd698c1554376335b907ea3a7b64 Mon Sep 17 00:00:00 2001 From: AnnurAfgoni Date: Tue, 21 May 2024 21:34:43 +0800 Subject: [PATCH] fix: add device parameter and locate model and data to device #16 --- src/onepiece_classify/infer/recognition.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/onepiece_classify/infer/recognition.py b/src/onepiece_classify/infer/recognition.py index 37f3a50..c11a987 100644 --- a/src/onepiece_classify/infer/recognition.py +++ b/src/onepiece_classify/infer/recognition.py @@ -12,8 +12,9 @@ class ImageRecognition(BaseInference): - def __init__(self, model_path: str): + def __init__(self, model_path: str, device: str): self.model_path = Path(model_path) + self.device = device self.class_dict = { 0: "Ace", 1: "Akainu", @@ -39,7 +40,7 @@ def __init__(self, model_path: str): def _build_model(self): # load model - state_dict = torch.load(self.model_path) + state_dict = torch.load(self.model_path, map_location=self.device) model_backbone = image_recog(self.nclass) model_backbone.load_state_dict(state_dict) return model_backbone @@ -65,7 +66,7 @@ def pre_process( else: print("Image type not recognized") - return img + return img.to(self.device) def forward(self, image_tensor: torch.Tensor) -> torch.Tensor: self.model.eval()