diff --git a/hot_fair_utilities/inference/predict.py b/hot_fair_utilities/inference/predict.py index 177ce0d..7d6a69d 100644 --- a/hot_fair_utilities/inference/predict.py +++ b/hot_fair_utilities/inference/predict.py @@ -76,12 +76,13 @@ def predict( for idx in range(0, len(image_paths), BATCH_SIZE): batch = image_paths[idx:idx + BATCH_SIZE] for i, r in enumerate(model(batch, stream=True, conf=confidence, verbose=False)): - if r.masks is None: - preds = np.zeros((IMAGE_SIZE, IMAGE_SIZE,), dtype=np.float32) - else: - preds = r.masks.data.max(dim=0)[0] # dim=0 means to take only footprint + if hasattr(r, 'masks') and r.masks is not None: # Check for segmentation output + preds = r.masks.data.max(dim=0)[0] # Take only footprint preds = torch.where(preds > confidence, torch.tensor(1), torch.tensor(0)) preds = preds.detach().cpu().numpy() + else: + preds = np.zeros((IMAGE_SIZE, IMAGE_SIZE,), dtype=np.float32) # Default if no masks + save_mask(preds, str(f"{prediction_path}/{Path(batch[i]).stem}.png")) else: raise RuntimeError("Loaded model is not supported")