diff --git a/src/deep_image_matching/extractors/extractor_base.py b/src/deep_image_matching/extractors/extractor_base.py index a3600a3..dfca7a2 100644 --- a/src/deep_image_matching/extractors/extractor_base.py +++ b/src/deep_image_matching/extractors/extractor_base.py @@ -205,7 +205,7 @@ def extract(self, img: Union[Image, Path, str]) -> np.ndarray: save_features_h5( feature_path, features, - im_path.name, + img.name, as_half=self.features_as_half, ) @@ -435,4 +435,4 @@ def viz_keypoints( [int(cv2.IMWRITE_JPEG_QUALITY), jpg_quality], ) else: - cv2.imwrite(out_path, out) + cv2.imwrite(out_path, out) \ No newline at end of file diff --git a/src/deep_image_matching/extractors/no_extractor.py b/src/deep_image_matching/extractors/no_extractor.py index 1675966..ba06ec0 100644 --- a/src/deep_image_matching/extractors/no_extractor.py +++ b/src/deep_image_matching/extractors/no_extractor.py @@ -43,7 +43,6 @@ def extract(self, img: Union[Image, Path, str]) -> np.ndarray: output_dir = Path(self.config["general"]["output_dir"]) feature_path = output_dir / "features.h5" output_dir.mkdir(parents=True, exist_ok=True) - im_name = im_path.name # Build fake features features = {} @@ -51,6 +50,7 @@ def extract(self, img: Union[Image, Path, str]) -> np.ndarray: features["descriptors"] = np.array([]) features["scores"] = np.array([]) img_obj = Image(im_path) + im_name = img_obj.name # img_obj.read_exif() features["image_size"] = np.array(img_obj.size) features["tile_idx"] = np.array([]) @@ -99,4 +99,4 @@ def _frame2tensor(self, image: np.ndarray, device: str = "cuda"): if __name__ == "__main__": - pass + pass \ No newline at end of file diff --git a/src/deep_image_matching/image_matching.py b/src/deep_image_matching/image_matching.py index f7abe27..54001d8 100644 --- a/src/deep_image_matching/image_matching.py +++ b/src/deep_image_matching/image_matching.py @@ -420,6 +420,7 @@ def rotate_upright_images( pairs = [(item[0].name, item[1].name) for item in self.pairs] path_to_upright_dir = self.output_dir / "upright_images" os.makedirs(path_to_upright_dir, exist_ok=False) + # I guess will break here, use recursive folder iterator images = os.listdir(self.image_dir) logger.info(f"Copying images to {path_to_upright_dir}") @@ -745,4 +746,4 @@ def rotate_back_features(self, feature_path: Path) -> None: if isinstance(v, np.ndarray): grp.create_dataset(k, data=v) - logger.info("Features rotated back.") + logger.info("Features rotated back.") \ No newline at end of file diff --git a/src/deep_image_matching/io/h5_to_db.py b/src/deep_image_matching/io/h5_to_db.py index 4e9ed71..13c53fd 100644 --- a/src/deep_image_matching/io/h5_to_db.py +++ b/src/deep_image_matching/io/h5_to_db.py @@ -25,9 +25,11 @@ import h5py import numpy as np import yaml -from PIL import ExifTags, Image +from PIL import ExifTags +from PIL import Image as PIL_Image from tqdm import tqdm +from ..utils.image import Image from ..utils.database import COLMAPDatabase, image_ids_to_pair_id logger = logging.getLogger("dim") @@ -127,7 +129,7 @@ def get_focal(image_path: Path, err_on_default: bool = False) -> float: This function calculates the focal length based on the maximum size of the image and the EXIF data. If the focal length cannot be determined from the EXIF data, it uses a default prior value. """ - image = Image.open(image_path) + image = PIL_Image.open(image_path) max_size = max(image.size) exif = image.getexif() @@ -156,7 +158,7 @@ def get_focal(image_path: Path, err_on_default: bool = False) -> float: def create_camera(db: Path, image_path: Path, camera_model: str): - image = Image.open(image_path) + image = PIL_Image.open(image_path) width, height = image.size focal = get_focal(image_path) @@ -237,7 +239,7 @@ def add_keypoints(db: Path, h5_path: Path, image_path: Path, camera_options: dic with h5py.File(str(h5_path), "r") as keypoint_f: fname_to_id = {} - k = 0 + created_cameras = {} for filename in tqdm(list(keypoint_f.keys())): keypoints = keypoint_f[filename]["keypoints"].__array__() @@ -247,19 +249,31 @@ def add_keypoints(db: Path, h5_path: Path, image_path: Path, camera_options: dic if filename not in list(grouped_images.keys()): if camera_options["general"]["single_camera"] is False: - camera_id = create_camera(db, path, camera_options["general"]["camera_model"]) + image = Image(path) + if image.camera_id != None: + if image.camera_id not in created_cameras: + camera_id = create_camera( + db, path, camera_options[f"cam{image.camera_id}"]["camera_model"] + ) + created_cameras[image.camera_id] = camera_id + else: + camera_id = created_cameras[image.camera_id] + else: + camera_id = create_camera( + db, path, camera_options["general"]["camera_model"] + ) + created_cameras[camera_id] = camera_id elif camera_options["general"]["single_camera"] is True: - if k == 0: - camera_id = create_camera(db, path, camera_options["general"]["camera_model"]) + if len(created_cameras) == 0: + camera_id = create_camera( + db, path, camera_options["general"]["camera_model"] + ) single_camera_id = camera_id - k += 1 - elif k > 0: + created_cameras[camera_id] = camera_id + else: camera_id = single_camera_id - elif filename in list(grouped_images.keys()): - camera_id = grouped_images[filename]["camera_id"] else: - print('ERROR in h5_to_db.py') - quit() + camera_id = grouped_images[filename]["camera_id"] image_id = db.add_image(filename, camera_id) fname_to_id[filename] = image_id @@ -402,4 +416,4 @@ def add_matches(db, h5_path, fname_to_id): fname_to_id, ) - db.commit() + db.commit() \ No newline at end of file diff --git a/src/deep_image_matching/matchers/loftr.py b/src/deep_image_matching/matchers/loftr.py index bec1eeb..5c8720b 100644 --- a/src/deep_image_matching/matchers/loftr.py +++ b/src/deep_image_matching/matchers/loftr.py @@ -8,6 +8,7 @@ from ..constants import TileSelection, Timer from ..utils.tiling import Tiler +from ..utils.image import Image from .matcher_base import DetectorFreeMatcherBase, tile_selection logger = logging.getLogger("dim") @@ -92,13 +93,15 @@ def _match_pairs( Raises: torch.cuda.OutOfMemoryError: If an out-of-memory error occurs while matching images. """ - - img0_name = img0_path.name - img1_name = img1_path.name + # Could just rename args but they might be used as keyword args elsewhere + img0 = img0_path + img1 = img1_path + img0_name = img0.name + img1_name = img0.name # Load images - image0 = self._load_image_np(img0_path) - image1 = self._load_image_np(img1_path) + image0 = self._load_image_np(img0.path) + image1 = self._load_image_np(img1.path) # Resize images if needed image0_ = self._resize_image(self._quality, image0) @@ -282,4 +285,4 @@ def _frame2tensor(self, image: np.ndarray, device: str = "cpu") -> torch.Tensor: if image.shape[1] > 2: image = K.color.bgr_to_rgb(image) image = K.color.rgb_to_grayscale(image) - return image + return image \ No newline at end of file diff --git a/src/deep_image_matching/matchers/matcher_base.py b/src/deep_image_matching/matchers/matcher_base.py index 8f8fba5..f6615c4 100644 --- a/src/deep_image_matching/matchers/matcher_base.py +++ b/src/deep_image_matching/matchers/matcher_base.py @@ -17,7 +17,7 @@ from ..thirdparty.hloc.extractors.superpoint import SuperPoint from ..thirdparty.LightGlue.lightglue import LightGlue from ..utils.geometric_verification import geometric_verification -from ..utils.image import resize_image +from ..utils.image import resize_image, Image from ..utils.tiling import Tiler from ..visualization import viz_matches_cv2, viz_matches_mpl @@ -205,10 +205,12 @@ def match( self._feature_path = Path(feature_path) # Get features from h5 file + img0 = Image(img0) + img1 = Image(img1) img0_name = img0.name img1_name = img1.name - features0 = get_features(self._feature_path, img0.name) - features1 = get_features(self._feature_path, img1.name) + features0 = get_features(self._feature_path, img0_name) + features1 = get_features(self._feature_path, img1_name) timer_match.update("load h5 features") # Perform matching (on tiles or full images) @@ -328,8 +330,8 @@ def match( self.viz_matches( feature_path, matches_path, - img0, - img1, + img0.path, + img1.path, save_path=viz_dir / f"{img0_name}_{img1_name}.jpg", img_format="jpg", jpg_quality=70, @@ -466,14 +468,14 @@ def viz_matches( jpg_quality = kwargs.get("jpg_quality", 80) hide_matching_track = kwargs.get("hide_matching_track", False) - img0 = Path(img0) - img1 = Path(img1) + img0 = Image(img0) + img1 = Image(img1) img0_name = img0.name img1_name = img1.name # Load images - image0 = load_image_np(img0, as_float=False, grayscale=True) - image1 = load_image_np(img1, as_float=False, grayscale=True) + image0 = load_image_np(img0.path, as_float=False, grayscale=True) + image1 = load_image_np(img1.path, as_float=False, grayscale=True) # Load features and matches features0 = get_features(feature_path, img0_name) @@ -648,8 +650,8 @@ def match( else: self._feature_path = Path(feature_path) - img0 = Path(img0) - img1 = Path(img1) + img0 = Image(img0) + img1 = Image(img1) img0_name = img0.name img1_name = img1.name @@ -672,7 +674,7 @@ def match( features1 = get_features(feature_path, img1_name) # Rescale threshold according the image original image size - img_shape = cv2.imread(str(img0)).shape + img_shape = cv2.imread(img0.path).shape scale_fct = np.floor(max(img_shape) / self.max_tile_size / 2) gv_threshold = self.config["general"]["gv_threshold"] * scale_fct @@ -854,14 +856,14 @@ def viz_matches( logger.warning("interactive_viz is ignored if fast_viz is True") assert save_path is not None, "output_dir must be specified if fast_viz is True" - img0 = Path(img0) - img1 = Path(img1) + img0 = Image(img0) + img1 = Image(img1) img0_name = img0.name img1_name = img1.name # Load images - image0 = load_image_np(img0, self.as_float, self.grayscale) - image1 = load_image_np(img1, self.as_float, self.grayscale) + image0 = load_image_np(img0.path, self.as_float, self.grayscale) + image1 = load_image_np(img1.path, self.as_float, self.grayscale) # Load features and matches features0 = get_features(feature_path, img0_name) @@ -1181,4 +1183,4 @@ def sp2lg(feats: dict) -> dict: def rbd2np(data: dict) -> dict: """Remove batch dimension from elements in data""" - return {k: v[0].cpu().numpy() if isinstance(v, (torch.Tensor, np.ndarray, list)) else v for k, v in data.items()} + return {k: v[0].cpu().numpy() if isinstance(v, (torch.Tensor, np.ndarray, list)) else v for k, v in data.items()} \ No newline at end of file diff --git a/src/deep_image_matching/matchers/roma.py b/src/deep_image_matching/matchers/roma.py index 50cd6fe..84f2bcd 100644 --- a/src/deep_image_matching/matchers/roma.py +++ b/src/deep_image_matching/matchers/roma.py @@ -13,6 +13,7 @@ from ..io.h5 import get_features from ..thirdparty.RoMa.roma import roma_outdoor from ..utils.geometric_verification import geometric_verification +from ..utils.image import Image from ..utils.tiling import Tiler from ..visualization import viz_matches_cv2 from .matcher_base import DetectorFreeMatcherBase, tile_selection @@ -115,8 +116,8 @@ def match( self._feature_path = Path(feature_path) # Get features from h5 file - img0 = Path(img0) - img1 = Path(img1) + img0 = Image(img0) + img1 = Image(img1) img0_name = img0.name img1_name = img1.name @@ -139,7 +140,7 @@ def match( features1 = get_features(feature_path, img1_name) # Rescale threshold according the image original image size - img_shape = cv2.imread(str(img0)).shape + img_shape = cv2.imread(img0.path).shape tile_size = max(self.config["general"]["tile_size"]) scale_fct = np.floor(max(img_shape) / tile_size / 2) gv_threshold = self.config["general"]["gv_threshold"] * scale_fct @@ -172,8 +173,8 @@ def match( def _match_pairs( self, feature_path: Path, - img0_path: Path, - img1_path: Path, + img0: Image, + img1: Image ): """ Perform matching between feature pairs. @@ -187,12 +188,12 @@ def _match_pairs( np.ndarray: Array containing the indices of matched keypoints. """ - img0_name = img0_path.name - img1_name = img1_path.name + img0_name = img0.name + img1_name = img1.name # Run inference - W_A, H_A = Image.open(img0_path).size - W_B, H_B = Image.open(img1_path).size + W_A, H_A = Image.open(img0.path).size + W_B, H_B = Image.open(img1.path).size #for path in [str(img0_path), str(img1_path)]: # image = cv2.imread(path, cv2.IMREAD_UNCHANGED) @@ -200,7 +201,7 @@ def _match_pairs( # image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) # cv2.imwrite(path, image_rgb) - warp, certainty = self.matcher.match(str(img0_path), str(img1_path), device=self._device) + warp, certainty = self.matcher.match(img0.path, img1.path, device=self._device) matches, certainty = self.matcher.sample(warp, certainty) kptsA, kptsB = self.matcher.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B) kptsA, kptsB = kptsA.cpu().numpy(), kptsB.cpu().numpy() @@ -283,8 +284,8 @@ def write_tiles_disk(output_dir: Path, tiles: dict) -> None: timer.update("tile selection") # Read images and resize them if needed - image0 = cv2.imread(str(img0)) - image1 = cv2.imread(str(img1)) + image0 = cv2.imread(img0.path) + image1 = cv2.imread(img1.path) image0 = self._resize_image(self._quality, image0) image1 = self._resize_image(self._quality, image1) @@ -431,4 +432,4 @@ def kps_in_image(kp, img_size, border_thr=2): if not self.keep_tiles: shutil.rmtree(tiles_dir) - return matches + return matches \ No newline at end of file diff --git a/src/deep_image_matching/utils/geometric_verification.py b/src/deep_image_matching/utils/geometric_verification.py index 3e24e26..6c66d79 100644 --- a/src/deep_image_matching/utils/geometric_verification.py +++ b/src/deep_image_matching/utils/geometric_verification.py @@ -180,4 +180,4 @@ def geometric_verification( confidence=0.9999, max_iters=10000, ) - print(F) + print(F) \ No newline at end of file diff --git a/src/deep_image_matching/utils/image.py b/src/deep_image_matching/utils/image.py index 5256322..bb5c085 100644 --- a/src/deep_image_matching/utils/image.py +++ b/src/deep_image_matching/utils/image.py @@ -111,20 +111,39 @@ def __init__(self, path: Union[str, Path], id: int = None) -> None: self._exif_data = None self._date_time = None self._focal_length = None - + self._camera_id = None + self._name = None + if "cam" in str(path): + for i, part in enumerate(path.parts): + if part.startswith("cam") and part[3:].isdigit(): + self._camera_id = eval(part[3:]) + rel_path = Path(*path.parts[i:]) + self._name = str(rel_path) + else: + self._name = path.name + try: self.read_exif() except Exception: img = PIL.Image.open(path) self._width, self._height = img.size + + img = PIL.Image.open(path) + self._width, self._height = img.size + def __repr__(self) -> str: """Returns a string representation of the image""" - return f"Image {self._path}" + return f"Image {self.name}" def __str__(self) -> str: """Returns a string representation of the image""" - return f"Image {self._path}" + return f"Image {self.name}" + + @property + def camera_id(self) -> int: + """Returns the camera_id of the image, if defined""" + return self._camera_id @property def id(self) -> int: @@ -136,12 +155,12 @@ def id(self) -> int: @property def name(self) -> str: """Returns the name of the image (including extension)""" - return self._path.name + return self._name @property def stem(self) -> str: """Returns the name of the image (excluding extension)""" - return self._path.stem + return self._name.stem @property def path(self) -> Path: @@ -408,7 +427,7 @@ def __init__(self, img_dir: Path): self.images = [] self.current_idx = 0 i = 0 - all_imgs = [image for image in img_dir.glob("*") if image.suffix in self.IMAGE_EXT] + all_imgs = [image for image in img_dir.rglob("*") if image.suffix in self.IMAGE_EXT] all_imgs.sort() if len(all_imgs) == 0: @@ -478,4 +497,4 @@ def img_paths(self): img_list = ImageList(image_dir) - print("done") + print("done") \ No newline at end of file