Skip to content

Commit

Permalink
check for fix prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitijrajsharma committed Oct 30, 2024
1 parent 9e37da6 commit d02e97d
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions hot_fair_utilities/inference/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,13 @@ def predict(
for idx in range(0, len(image_paths), BATCH_SIZE):
batch = image_paths[idx:idx + BATCH_SIZE]
for i, r in enumerate(model(batch, stream=True, conf=confidence, verbose=False)):
if r.masks is None:
preds = np.zeros((IMAGE_SIZE, IMAGE_SIZE,), dtype=np.float32)
else:
preds = r.masks.data.max(dim=0)[0] # dim=0 means to take only footprint
if hasattr(r, 'masks') and r.masks is not None: # Check for segmentation output
preds = r.masks.data.max(dim=0)[0] # Take only footprint
preds = torch.where(preds > confidence, torch.tensor(1), torch.tensor(0))
preds = preds.detach().cpu().numpy()
else:
preds = np.zeros((IMAGE_SIZE, IMAGE_SIZE,), dtype=np.float32) # Default if no masks

save_mask(preds, str(f"{prediction_path}/{Path(batch[i]).stem}.png"))
else:
raise RuntimeError("Loaded model is not supported")
Expand Down

0 comments on commit d02e97d

Please sign in to comment.