Skip to content

Commit

Permalink
Merge pull request #17 from lombokai/bugfix/predict
Browse files Browse the repository at this point in the history
Bugfix/predict
  • Loading branch information
nunenuh authored May 22, 2024
2 parents 6f566c6 + 351dac6 commit b98c4eb
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
5 changes: 3 additions & 2 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

def predict_image(
image_path: str = typer.Argument(help="image path", show_default=True),
model_path: str = typer.Argument("checkpoint/checkpoint_notebook.pth", help="model path (pth)", show_default=True)
model_path: str = typer.Argument("checkpoint/checkpoint_notebook.pth", help="model path (pth)", show_default=True),
device: str = typer.Argument("cpu", help="use cuda if your device has cuda", show_default=True)
):
predictor = ImageRecognition(model_path=model_path)
predictor = ImageRecognition(model_path=model_path, device=device)
result = predictor.predict(image=image_path)
typer.echo(f"Prediction: {result}")

Expand Down
7 changes: 4 additions & 3 deletions src/onepiece_classify/infer/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@


class ImageRecognition(BaseInference):
def __init__(self, model_path: str):
def __init__(self, model_path: str, device: str):
self.model_path = Path(model_path)
self.device = device
self.class_dict = {
0: "Ace",
1: "Akainu",
Expand All @@ -39,7 +40,7 @@ def __init__(self, model_path: str):

def _build_model(self):
# load model
state_dict = torch.load(self.model_path)
state_dict = torch.load(self.model_path, map_location=self.device)
model_backbone = image_recog(self.nclass)
model_backbone.load_state_dict(state_dict)
return model_backbone
Expand All @@ -65,7 +66,7 @@ def pre_process(
else:
print("Image type not recognized")

return img
return img.to(self.device)

def forward(self, image_tensor: torch.Tensor) -> torch.Tensor:
self.model.eval()
Expand Down

0 comments on commit b98c4eb

Please sign in to comment.