Skip to content

Commit

Permalink
fix: add device parameter and locate model and data to device #16
Browse files Browse the repository at this point in the history
  • Loading branch information
nurgoni committed May 21, 2024
1 parent 8f254e9 commit 8fbe6bc
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/onepiece_classify/infer/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@


class ImageRecognition(BaseInference):
def __init__(self, model_path: str):
def __init__(self, model_path: str, device: str):
self.model_path = Path(model_path)
self.device = device
self.class_dict = {
0: "Ace",
1: "Akainu",
Expand All @@ -39,7 +40,7 @@ def __init__(self, model_path: str):

def _build_model(self):
# load model
state_dict = torch.load(self.model_path)
state_dict = torch.load(self.model_path, map_location=self.device)
model_backbone = image_recog(self.nclass)
model_backbone.load_state_dict(state_dict)
return model_backbone
Expand All @@ -65,7 +66,7 @@ def pre_process(
else:
print("Image type not recognized")

return img
return img.to(self.device)

def forward(self, image_tensor: torch.Tensor) -> torch.Tensor:
self.model.eval()
Expand Down

0 comments on commit 8fbe6bc

Please sign in to comment.