From 8fbe6bcb31b6dd698c1554376335b907ea3a7b64 Mon Sep 17 00:00:00 2001 From: AnnurAfgoni Date: Tue, 21 May 2024 21:34:43 +0800 Subject: [PATCH 1/2] fix: add device parameter and locate model and data to device #16 --- src/onepiece_classify/infer/recognition.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/onepiece_classify/infer/recognition.py b/src/onepiece_classify/infer/recognition.py index 37f3a50..c11a987 100644 --- a/src/onepiece_classify/infer/recognition.py +++ b/src/onepiece_classify/infer/recognition.py @@ -12,8 +12,9 @@ class ImageRecognition(BaseInference): - def __init__(self, model_path: str): + def __init__(self, model_path: str, device: str): self.model_path = Path(model_path) + self.device = device self.class_dict = { 0: "Ace", 1: "Akainu", @@ -39,7 +40,7 @@ def __init__(self, model_path: str): def _build_model(self): # load model - state_dict = torch.load(self.model_path) + state_dict = torch.load(self.model_path, map_location=self.device) model_backbone = image_recog(self.nclass) model_backbone.load_state_dict(state_dict) return model_backbone @@ -65,7 +66,7 @@ def pre_process( else: print("Image type not recognized") - return img + return img.to(self.device) def forward(self, image_tensor: torch.Tensor) -> torch.Tensor: self.model.eval() From 351dac6e5ac414d8c6fb6a891949285721897830 Mon Sep 17 00:00:00 2001 From: AnnurAfgoni Date: Tue, 21 May 2024 21:35:52 +0800 Subject: [PATCH 2/2] fix: add device parameter with default value cpu --- predict.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/predict.py b/predict.py index 36e4d65..ccfe2bc 100644 --- a/predict.py +++ b/predict.py @@ -8,9 +8,10 @@ def predict_image( image_path: str = typer.Argument(help="image path", show_default=True), - model_path: str = typer.Argument("checkpoint/checkpoint_notebook.pth", help="model path (pth)", show_default=True) + model_path: str = typer.Argument("checkpoint/checkpoint_notebook.pth", help="model path (pth)", show_default=True), + device: str = typer.Argument("cpu", help="use cuda if your device has cuda", show_default=True) ): - predictor = ImageRecognition(model_path=model_path) + predictor = ImageRecognition(model_path=model_path, device=device) result = predictor.predict(image=image_path) typer.echo(f"Prediction: {result}")