diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index a7806059afaa..7feeebcab54a 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -687,6 +687,8 @@ title: SegFormer - local: model_doc/seggpt title: SegGpt + - local: model_doc/superglue + title: SuperGlue - local: model_doc/superpoint title: SuperPoint - local: model_doc/swiftformer diff --git a/docs/source/en/index.md b/docs/source/en/index.md index aaff45ab65df..56dc9c1c2422 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -303,6 +303,7 @@ Flax), PyTorch, and/or TensorFlow. | [SqueezeBERT](model_doc/squeezebert) | ✅ | ❌ | ❌ | | [StableLm](model_doc/stablelm) | ✅ | ❌ | ❌ | | [Starcoder2](model_doc/starcoder2) | ✅ | ❌ | ❌ | +| [SuperGlue](model_doc/superglue) | ✅ | ❌ | ❌ | | [SuperPoint](model_doc/superpoint) | ✅ | ❌ | ❌ | | [SwiftFormer](model_doc/swiftformer) | ✅ | ✅ | ❌ | | [Swin Transformer](model_doc/swin) | ✅ | ✅ | ❌ | diff --git a/docs/source/en/model_doc/superglue.md b/docs/source/en/model_doc/superglue.md new file mode 100644 index 000000000000..08a4575dddc2 --- /dev/null +++ b/docs/source/en/model_doc/superglue.md @@ -0,0 +1,138 @@ + + +# SuperGlue + +## Overview + +The SuperGlue model was proposed in [SuperGlue: Learning Feature Matching with Graph Neural Networks](https://arxiv.org/abs/1911.11763) by Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz and Andrew Rabinovich. + +This model consists of matching two sets of interest points detected in an image. Paired with the +[SuperPoint model](https://huggingface.co/magic-leap-community/superpoint), it can be used to match two images and +estimate the pose between them. This model is useful for tasks such as image matching, homography estimation, etc. + +The abstract from the paper is the following: + +*This paper introduces SuperGlue, a neural network that matches two sets of local features by jointly finding correspondences +and rejecting non-matchable points. Assignments are estimated by solving a differentiable optimal transport problem, whose costs +are predicted by a graph neural network. We introduce a flexible context aggregation mechanism based on attention, enabling +SuperGlue to reason about the underlying 3D scene and feature assignments jointly. Compared to traditional, hand-designed heuristics, +our technique learns priors over geometric transformations and regularities of the 3D world through end-to-end training from image +pairs. SuperGlue outperforms other learned approaches and achieves state-of-the-art results on the task of pose estimation in +challenging real-world indoor and outdoor environments. The proposed method performs matching in real-time on a modern GPU and +can be readily integrated into modern SfM or SLAM systems. The code and trained weights are publicly available at this [URL](https://github.com/magicleap/SuperGluePretrainedNetwork).* + +## How to use + +Here is a quick example of using the model. Since this model is an image matching model, it requires pairs of images to be matched. +The raw outputs contain the list of keypoints detected by the keypoint detector as well as the list of matches with their corresponding +matching scores. 
+```python
+from transformers import AutoImageProcessor, AutoModel
+import torch
+from PIL import Image
+import requests
+
+url_image1 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"
+image1 = Image.open(requests.get(url_image1, stream=True).raw)
+url_image2 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg"
+image2 = Image.open(requests.get(url_image2, stream=True).raw)
+
+images = [image1, image2]
+
+processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
+model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")
+
+inputs = processor(images, return_tensors="pt")
+with torch.no_grad():
+    outputs = model(**inputs)
+```
+
+You can use the `post_process_keypoint_matching` method from the `SuperGlueImageProcessor` to get the keypoints and matches in a more readable format:
+
+```python
+image_sizes = [[(image.height, image.width) for image in images]]
+outputs = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
+for i, output in enumerate(outputs):
+    print("For the image pair", i)
+    for keypoint0, keypoint1, matching_score in zip(
+        output["keypoints0"], output["keypoints1"], output["matching_scores"]
+    ):
+        print(
+            f"Keypoint at coordinate {keypoint0.numpy()} in the first image matches with keypoint at coordinate {keypoint1.numpy()} in the second image with a score of {matching_score}."
+        )
+```
+
+From the outputs, you can visualize the matches between the two images using the following code:
+```python
+import matplotlib.pyplot as plt
+import numpy as np
+
+# Create side by side image
+merged_image = np.zeros((max(image1.height, image2.height), image1.width + image2.width, 3))
+merged_image[: image1.height, : image1.width] = np.array(image1) / 255.0
+merged_image[: image2.height, image1.width :] = np.array(image2) / 255.0
+plt.imshow(merged_image)
+plt.axis("off")
+
+# Retrieve the keypoints and matches
+output = outputs[0]
+keypoints0 = output["keypoints0"]
+keypoints1 = output["keypoints1"]
+matching_scores = output["matching_scores"]
+keypoints0_x, keypoints0_y = keypoints0[:, 0].numpy(), keypoints0[:, 1].numpy()
+keypoints1_x, keypoints1_y = keypoints1[:, 0].numpy(), keypoints1[:, 1].numpy()
+
+# Plot the matches
+for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
+    keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, matching_scores
+):
+    plt.plot(
+        [keypoint0_x, keypoint1_x + image1.width],
+        [keypoint0_y, keypoint1_y],
+        color=plt.get_cmap("RdYlGn")(matching_score.item()),
+        alpha=0.9,
+        linewidth=0.5,
+    )
+    plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2)
+    plt.scatter(keypoint1_x + image1.width, keypoint1_y, c="black", s=2)
+
+# Save the plot
+plt.savefig("matched_image.png", dpi=300, bbox_inches='tight')
+plt.close()
+```
+
+![image/png](https://cdn-uploads.huggingface.co/production/uploads/632885ba1558dac67c440aa8/01ZYaLB1NL5XdA8u7yCo4.png)
+
+This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
+The original code can be found [here](https://github.com/magicleap/SuperGluePretrainedNetwork).
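+
+As mentioned in the overview, the matched keypoints can also feed a geometric estimator. The snippet below is an illustrative sketch (not from the original SuperGlue repository): it continues from the `outputs` returned by `post_process_keypoint_matching` above, assumes OpenCV is installed (`pip install opencv-python`), and uses an arbitrary RANSAC reprojection threshold.
+
+```python
+import cv2
+import numpy as np
+
+# Matched keypoints for the first (and only) image pair, in pixel coordinates
+output = outputs[0]
+points0 = output["keypoints0"].numpy().astype(np.float32)
+points1 = output["keypoints1"].numpy().astype(np.float32)
+
+if len(points0) >= 4:  # a homography needs at least 4 correspondences
+    # Robustly fit a homography mapping image1 coordinates to image2 coordinates
+    homography, inlier_mask = cv2.findHomography(points0, points1, cv2.RANSAC, ransacReprojThreshold=3.0)
+    if homography is not None:
+        print(f"Estimated homography with {int(inlier_mask.sum())}/{len(points0)} RANSAC inliers:")
+        print(homography)
+```
+
+The inlier mask can be used, for instance, to keep only geometrically consistent matches in the visualization above, or the estimated homography can be passed to `cv2.warpPerspective` to align one image with the other.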
+ +## SuperGlueConfig + +[[autodoc]] SuperGlueConfig + +## SuperGlueImageProcessor + +[[autodoc]] SuperGlueImageProcessor + +- preprocess + +## SuperGlueForKeypointMatching + +[[autodoc]] SuperGlueForKeypointMatching + +- forward +- post_process_keypoint_matching \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 47b43e0b9089..37a30256d3d5 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -759,6 +759,7 @@ ], "models.stablelm": ["StableLmConfig"], "models.starcoder2": ["Starcoder2Config"], + "models.superglue": ["SuperGlueConfig"], "models.superpoint": ["SuperPointConfig"], "models.swiftformer": ["SwiftFormerConfig"], "models.swin": ["SwinConfig"], @@ -1234,6 +1235,7 @@ _import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"]) _import_structure["models.seggpt"].extend(["SegGptImageProcessor"]) _import_structure["models.siglip"].append("SiglipImageProcessor") + _import_structure["models.superglue"].extend(["SuperGlueImageProcessor"]) _import_structure["models.superpoint"].extend(["SuperPointImageProcessor"]) _import_structure["models.swin2sr"].append("Swin2SRImageProcessor") _import_structure["models.tvp"].append("TvpImageProcessor") @@ -3402,6 +3404,12 @@ "Starcoder2PreTrainedModel", ] ) + _import_structure["models.superglue"].extend( + [ + "SuperGlueForKeypointMatching", + "SuperGluePreTrainedModel", + ] + ) _import_structure["models.superpoint"].extend( [ "SuperPointForKeypointDetection", @@ -5664,6 +5672,7 @@ ) from .models.stablelm import StableLmConfig from .models.starcoder2 import Starcoder2Config + from .models.superglue import SuperGlueConfig from .models.superpoint import SuperPointConfig from .models.swiftformer import ( SwiftFormerConfig, @@ -6159,6 +6168,7 @@ from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor from .models.seggpt import SegGptImageProcessor from .models.siglip import SiglipImageProcessor + from .models.superglue import SuperGlueImageProcessor from .models.superpoint import SuperPointImageProcessor from .models.swin2sr import Swin2SRImageProcessor from .models.tvp import TvpImageProcessor @@ -7901,6 +7911,10 @@ Starcoder2Model, Starcoder2PreTrainedModel, ) + from .models.superglue import ( + SuperGlueForKeypointMatching, + SuperGluePreTrainedModel, + ) from .models.superpoint import ( SuperPointForKeypointDetection, SuperPointPreTrainedModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 9155f629e63f..75beb0b0f2d4 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -234,6 +234,7 @@ squeezebert, stablelm, starcoder2, + superglue, superpoint, swiftformer, swin, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 48625ea3f346..af3c7a16e345 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -259,6 +259,7 @@ ("squeezebert", "SqueezeBertConfig"), ("stablelm", "StableLmConfig"), ("starcoder2", "Starcoder2Config"), + ("superglue", "SuperGlueConfig"), ("superpoint", "SuperPointConfig"), ("swiftformer", "SwiftFormerConfig"), ("swin", "SwinConfig"), @@ -575,6 +576,7 @@ ("squeezebert", "SqueezeBERT"), ("stablelm", "StableLm"), ("starcoder2", "Starcoder2"), + ("superglue", "SuperGlue"), ("superpoint", "SuperPoint"), ("swiftformer", "SwiftFormer"), ("swin", "Swin Transformer"), diff --git 
a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index a8960d80acc8..21a68ef5b6b8 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -129,6 +129,7 @@
         ("segformer", ("SegformerImageProcessor",)),
         ("seggpt", ("SegGptImageProcessor",)),
         ("siglip", ("SiglipImageProcessor",)),
+        ("superglue", ("SuperGlueImageProcessor",)),
         ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
         ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),
         ("swin2sr", ("Swin2SRImageProcessor",)),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 67c539fca664..6db6d44d6cd8 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -239,6 +239,7 @@
         ("squeezebert", "SqueezeBertModel"),
         ("stablelm", "StableLmModel"),
         ("starcoder2", "Starcoder2Model"),
+        ("superglue", "SuperGlueForKeypointMatching"),
         ("swiftformer", "SwiftFormerModel"),
         ("swin", "SwinModel"),
         ("swin2sr", "Swin2SRModel"),
diff --git a/src/transformers/models/superglue/__init__.py b/src/transformers/models/superglue/__init__.py
new file mode 100644
index 000000000000..666f44026eff
--- /dev/null
+++ b/src/transformers/models/superglue/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_superglue import *
+    from .image_processing_superglue import *
+    from .modeling_superglue import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/superglue/configuration_superglue.py b/src/transformers/models/superglue/configuration_superglue.py
new file mode 100644
index 000000000000..fe301442d632
--- /dev/null
+++ b/src/transformers/models/superglue/configuration_superglue.py
@@ -0,0 +1,120 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING, List + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING + + +if TYPE_CHECKING: + from ..superpoint import SuperPointConfig + +logger = logging.get_logger(__name__) + + +class SuperGlueConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SuperGlueModel`]. It is used to instantiate a + SuperGlue model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the SuperGlue + [magic-leap-community/superglue_indoor](https://huggingface.co/magic-leap-community/superglue_indoor) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + keypoint_detector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SuperPointConfig`): + The config object or dictionary of the keypoint detector. + hidden_size (`int`, *optional*, defaults to 256): + The dimension of the descriptors. + keypoint_encoder_sizes (`List[int]`, *optional*, defaults to `[32, 64, 128, 256]`): + The sizes of the keypoint encoder layers. + gnn_layers_types (`List[str]`, *optional*, defaults to `['self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross']`): + The types of the GNN layers. Must be either 'self' or 'cross'. + num_attention_heads (`int`, *optional*, defaults to 4): + The number of heads in the GNN layers. + sinkhorn_iterations (`int`, *optional*, defaults to 100): + The number of Sinkhorn iterations. + matching_threshold (`float`, *optional*, defaults to 0.0): + The matching threshold. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+
+    Examples:
+    ```python
+    >>> from transformers import SuperGlueConfig, SuperGlueModel
+
+    >>> # Initializing a SuperGlue superglue style configuration
+    >>> configuration = SuperGlueConfig()
+
+    >>> # Initializing a model from the superglue style configuration
+    >>> model = SuperGlueModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "superglue"
+
+    def __init__(
+        self,
+        keypoint_detector_config: "SuperPointConfig" = None,
+        hidden_size: int = 256,
+        keypoint_encoder_sizes: List[int] = None,
+        gnn_layers_types: List[str] = None,
+        num_attention_heads: int = 4,
+        sinkhorn_iterations: int = 100,
+        matching_threshold: float = 0.0,
+        initializer_range: float = 0.02,
+        **kwargs,
+    ):
+        self.gnn_layers_types = gnn_layers_types if gnn_layers_types is not None else ["self", "cross"] * 9
+        # Check whether all gnn_layers_types are either 'self' or 'cross'
+        if not all(layer_type in ["self", "cross"] for layer_type in self.gnn_layers_types):
+            raise ValueError("All gnn_layers_types must be either 'self' or 'cross'")
+
+        if hidden_size % num_attention_heads != 0:
+            raise ValueError("hidden_size % num_attention_heads is different from zero")
+
+        self.keypoint_encoder_sizes = (
+            keypoint_encoder_sizes if keypoint_encoder_sizes is not None else [32, 64, 128, 256]
+        )
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.sinkhorn_iterations = sinkhorn_iterations
+        self.matching_threshold = matching_threshold
+
+        if isinstance(keypoint_detector_config, dict):
+            keypoint_detector_config["model_type"] = (
+                keypoint_detector_config["model_type"] if "model_type" in keypoint_detector_config else "superpoint"
+            )
+            keypoint_detector_config = CONFIG_MAPPING[keypoint_detector_config["model_type"]](
+                **keypoint_detector_config
+            )
+        if keypoint_detector_config is None:
+            keypoint_detector_config = CONFIG_MAPPING["superpoint"]()
+
+        self.keypoint_detector_config = keypoint_detector_config
+        self.initializer_range = initializer_range
+        self.attention_probs_dropout_prob = 0
+        self.is_decoder = False
+
+        super().__init__(**kwargs)
+
+
+__all__ = ["SuperGlueConfig"]
diff --git a/src/transformers/models/superglue/convert_superglue_to_hf.py b/src/transformers/models/superglue/convert_superglue_to_hf.py
new file mode 100644
index 000000000000..cfff39acdfd8
--- /dev/null
+++ b/src/transformers/models/superglue/convert_superglue_to_hf.py
@@ -0,0 +1,342 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse +import gc +import os +import re +from typing import List + +import torch +from datasets import load_dataset + +from transformers import ( + AutoModelForKeypointDetection, + SuperGlueConfig, + SuperGlueForKeypointMatching, + SuperGlueImageProcessor, +) + + +def prepare_imgs(): + dataset = load_dataset("hf-internal-testing/image-matching-test-dataset", split="train") + image1 = dataset[0]["image"] + image2 = dataset[1]["image"] + image3 = dataset[2]["image"] + return [[image1, image2], [image3, image2]] + + +def verify_model_outputs(model, model_name, device): + images = prepare_imgs() + preprocessor = SuperGlueImageProcessor() + inputs = preprocessor(images=images, return_tensors="pt").to(device) + model.to(device) + with torch.no_grad(): + outputs = model(**inputs, output_hidden_states=True, output_attentions=True) + + predicted_matches_values = outputs.matches[0, 0, :10] + predicted_matching_scores_values = outputs.matching_scores[0, 0, :10] + + predicted_number_of_matches = torch.sum(outputs.matches[0][0] != -1).item() + + if "outdoor" in model_name: + expected_max_number_keypoints = 865 + expected_matches_shape = torch.Size((len(images), 2, expected_max_number_keypoints)) + expected_matching_scores_shape = torch.Size((len(images), 2, expected_max_number_keypoints)) + + expected_matches_values = torch.tensor( + [125, 630, 137, 138, 136, 143, 135, -1, -1, 153], dtype=torch.int64, device=device + ) + expected_matching_scores_values = torch.tensor( + [0.9899, 0.0033, 0.9897, 0.9889, 0.9879, 0.7464, 0.7109, 0, 0, 0.9841], device=device + ) + + expected_number_of_matches = 281 + elif "indoor" in model_name: + expected_max_number_keypoints = 865 + expected_matches_shape = torch.Size((len(images), 2, expected_max_number_keypoints)) + expected_matching_scores_shape = torch.Size((len(images), 2, expected_max_number_keypoints)) + + expected_matches_values = torch.tensor( + [125, 144, 137, 138, 136, 155, 135, -1, -1, 153], dtype=torch.int64, device=device + ) + expected_matching_scores_values = torch.tensor( + [0.9694, 0.0010, 0.9006, 0.8753, 0.8521, 0.5688, 0.6321, 0.0, 0.0, 0.7235], device=device + ) + + expected_number_of_matches = 282 + + assert outputs.matches.shape == expected_matches_shape + assert outputs.matching_scores.shape == expected_matching_scores_shape + + assert torch.allclose(predicted_matches_values, expected_matches_values, atol=1e-4) + assert torch.allclose(predicted_matching_scores_values, expected_matching_scores_values, atol=1e-4) + + assert predicted_number_of_matches == expected_number_of_matches + + +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + r"kenc.encoder.(\d+)": r"keypoint_encoder.encoder.\1.old", + r"gnn.layers.(\d+).attn.proj.0": r"gnn.layers.\1.attention.self.query", + r"gnn.layers.(\d+).attn.proj.1": r"gnn.layers.\1.attention.self.key", + r"gnn.layers.(\d+).attn.proj.2": r"gnn.layers.\1.attention.self.value", + r"gnn.layers.(\d+).attn.merge": r"gnn.layers.\1.attention.output.dense", + r"gnn.layers.(\d+).mlp.0": r"gnn.layers.\1.mlp.0.linear", + r"gnn.layers.(\d+).mlp.1": r"gnn.layers.\1.mlp.0.batch_norm", + r"gnn.layers.(\d+).mlp.3": r"gnn.layers.\1.mlp.1", + r"final_proj": r"final_projection.final_proj", +} + + +def convert_old_keys_to_new_keys(state_dict_keys: List[str], conversion_mapping=ORIGINAL_TO_CONVERTED_KEY_MAPPING): + """ + This function should be applied only once, on the concatenated keys to efficiently rename using + the key mappings. 
+ """ + output_dict = {} + if state_dict_keys is not None: + old_text = "\n".join(state_dict_keys) + new_text = old_text + for pattern, replacement in conversion_mapping.items(): + if replacement is None: + new_text = re.sub(pattern, "", new_text) # an empty line + continue + new_text = re.sub(pattern, replacement, new_text) + output_dict = dict(zip(old_text.split("\n"), new_text.split("\n"))) + return output_dict + + +def replace_state_dict_keys(all_keys, new_keys, original_state_dict): + state_dict = {} + for key in all_keys: + new_key = new_keys[key] + state_dict[new_key] = original_state_dict.pop(key).contiguous().clone() + return state_dict + + +def convert_state_dict(state_dict, config): + converted_to_final_key_mapping = {} + + def convert_conv_to_linear(keys): + for key in keys: + state_dict[key] = state_dict[key].squeeze(-1) + + def qkv_permute_weights_and_biases(keys, num_heads=4): + for key in keys: + tensor = state_dict[key] + shape = tensor.shape + dim_out = shape[0] + if len(shape) == 2: + dim_in = shape[1] + tensor = ( + tensor.reshape(dim_out // num_heads, num_heads, dim_in).permute(1, 0, 2).reshape(dim_out, dim_in) + ) + if len(shape) == 1: + tensor = tensor.reshape(dim_out // num_heads, num_heads).permute(1, 0).reshape(dim_out) + state_dict[key] = tensor + + def output_permute_weights(keys, num_heads=4): + for key in keys: + tensor = state_dict[key] + dim_in = tensor.shape[1] + dim_out = tensor.shape[0] + tensor = tensor.reshape(dim_out, dim_in // num_heads, num_heads).permute(0, 2, 1).reshape(dim_out, dim_in) + state_dict[key] = tensor + + conv_keys = [] + qkv_permute_keys = [] + output_permute_keys = [] + # Keypoint Encoder + keypoint_encoder_key = "keypoint_encoder.encoder" + for i in range(1, len(config.keypoint_encoder_sizes) + 2): + old_conv_key = f"{keypoint_encoder_key}.{(i - 1) * 3}.old" + new_index = i - 1 + new_conv_key = f"{keypoint_encoder_key}.{new_index}." + if i < len(config.keypoint_encoder_sizes) + 1: + new_conv_key = f"{new_conv_key}linear." + converted_to_final_key_mapping[rf"{old_conv_key}\."] = new_conv_key + if i < len(config.keypoint_encoder_sizes) + 1: + old_batch_norm_key = f"{keypoint_encoder_key}.{(i - 1) * 3 + 1}.old" + new_batch_norm_key = f"{keypoint_encoder_key}.{new_index}.batch_norm." 
+ converted_to_final_key_mapping[rf"{old_batch_norm_key}\."] = new_batch_norm_key + + conv_keys.append(f"{old_conv_key}.weight") + + # Attentional GNN + for i in range(len(config.gnn_layers_types)): + gnn_layer_key = f"gnn.layers.{i}" + ## Attention + attention_key = f"{gnn_layer_key}.attention" + conv_keys.extend( + [ + f"{attention_key}.self.query.weight", + f"{attention_key}.self.key.weight", + f"{attention_key}.self.value.weight", + f"{attention_key}.output.dense.weight", + ] + ) + qkv_permute_keys.extend( + [ + f"{attention_key}.self.query.weight", + f"{attention_key}.self.key.weight", + f"{attention_key}.self.value.weight", + f"{attention_key}.self.query.bias", + f"{attention_key}.self.key.bias", + f"{attention_key}.self.value.bias", + ] + ) + output_permute_keys.append(f"{attention_key}.output.dense.weight") + + ## MLP + conv_keys.extend([f"{gnn_layer_key}.mlp.0.linear.weight", f"{gnn_layer_key}.mlp.1.weight"]) + + # Final Projection + conv_keys.append("final_projection.final_proj.weight") + + convert_conv_to_linear(conv_keys) + qkv_permute_weights_and_biases(qkv_permute_keys) + output_permute_weights(output_permute_keys) + all_keys = list(state_dict.keys()) + new_keys = convert_old_keys_to_new_keys(all_keys, converted_to_final_key_mapping) + state_dict = replace_state_dict_keys(all_keys, new_keys, state_dict) + return state_dict + + +def add_keypoint_detector_state_dict(superglue_state_dict): + keypoint_detector = AutoModelForKeypointDetection.from_pretrained("magic-leap-community/superpoint") + keypoint_detector_state_dict = keypoint_detector.state_dict() + for k, v in keypoint_detector_state_dict.items(): + superglue_state_dict[f"keypoint_detector.{k}"] = v + return superglue_state_dict + + +@torch.no_grad() +def write_model( + model_path, + checkpoint_url, + safe_serialization=True, + push_to_hub=False, +): + os.makedirs(model_path, exist_ok=True) + + # ------------------------------------------------------------ + # SuperGlue config + # ------------------------------------------------------------ + + config = SuperGlueConfig( + hidden_size=256, + keypoint_encoder_sizes=[32, 64, 128, 256], + gnn_layers_types=["self", "cross"] * 9, + sinkhorn_iterations=100, + matching_threshold=0.0, + ) + config.architectures = ["SuperGlueForKeypointMatching"] + config.save_pretrained(model_path, push_to_hub=push_to_hub) + print("Model config saved successfully...") + + # ------------------------------------------------------------ + # Convert weights + # ------------------------------------------------------------ + + print(f"Fetching all parameters from the checkpoint at {checkpoint_url}...") + original_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url) + + print("Converting model...") + all_keys = list(original_state_dict.keys()) + new_keys = convert_old_keys_to_new_keys(all_keys) + + state_dict = replace_state_dict_keys(all_keys, new_keys, original_state_dict) + state_dict = convert_state_dict(state_dict, config) + + del original_state_dict + gc.collect() + state_dict = add_keypoint_detector_state_dict(state_dict) + + print("Loading the checkpoint in a SuperGlue model...") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + with torch.device(device): + model = SuperGlueForKeypointMatching(config) + model.load_state_dict(state_dict, strict=True) + print("Checkpoint loaded successfully...") + del model.config._name_or_path + + print("Saving the model...") + model.save_pretrained(model_path, safe_serialization=safe_serialization) + del state_dict, model + + # 
Safety check: reload the converted model + gc.collect() + print("Reloading the model to check if it's saved correctly.") + model = SuperGlueForKeypointMatching.from_pretrained(model_path) + print("Model reloaded successfully.") + + model_name = "superglue" + if "superglue_outdoor.pth" in checkpoint_url: + model_name += "_outdoor" + elif "superglue_indoor.pth" in checkpoint_url: + model_name += "_indoor" + + print("Checking the model outputs...") + verify_model_outputs(model, model_name, device) + print("Model outputs verified successfully.") + + organization = "magic-leap-community" + if push_to_hub: + print("Pushing model to the hub...") + model.push_to_hub( + repo_id=f"{organization}/{model_name}", + commit_message="Add model", + ) + + write_image_processor(model_path, model_name, organization, push_to_hub=push_to_hub) + + +def write_image_processor(save_dir, model_name, organization, push_to_hub=False): + image_processor = SuperGlueImageProcessor() + image_processor.save_pretrained(save_dir) + + if push_to_hub: + print("Pushing image processor to the hub...") + image_processor.push_to_hub( + repo_id=f"{organization}/{model_name}", + commit_message="Add image processor", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--checkpoint_url", + default="https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/master/models/weights/superglue_indoor.pth", + type=str, + help="URL of the original SuperGlue checkpoint you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + required=True, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument("--save_model", action="store_true", help="Save model to local") + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Push model and image preprocessor to the hub", + ) + + args = parser.parse_args() + write_model( + args.pytorch_dump_folder_path, args.checkpoint_url, safe_serialization=True, push_to_hub=args.push_to_hub + ) diff --git a/src/transformers/models/superglue/image_processing_superglue.py b/src/transformers/models/superglue/image_processing_superglue.py new file mode 100644 index 000000000000..567e55580701 --- /dev/null +++ b/src/transformers/models/superglue/image_processing_superglue.py @@ -0,0 +1,407 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for SuperPoint.""" + +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import numpy as np + +from ... 
import is_torch_available, is_vision_available +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import resize, to_channel_dimension_format +from ...image_utils import ( + ChannelDimension, + ImageInput, + ImageType, + PILImageResampling, + get_image_type, + infer_channel_dimension_format, + is_pil_image, + is_scaled_image, + is_valid_image, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, logging, requires_backends + + +if is_torch_available(): + import torch + +if TYPE_CHECKING: + from .modeling_superglue import KeypointMatchingOutput + +if is_vision_available(): + import PIL + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.superpoint.image_processing_superpoint.is_grayscale +def is_grayscale( + image: ImageInput, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +): + if input_data_format == ChannelDimension.FIRST: + if image.shape[0] == 1: + return True + return np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...]) + elif input_data_format == ChannelDimension.LAST: + if image.shape[-1] == 1: + return True + return np.all(image[..., 0] == image[..., 1]) and np.all(image[..., 1] == image[..., 2]) + + +# Copied from transformers.models.superpoint.image_processing_superpoint.convert_to_grayscale +def convert_to_grayscale( + image: ImageInput, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> ImageInput: + """ + Converts an image to grayscale format using the NTSC formula. Only support numpy and PIL Image. TODO support torch + and tensorflow grayscale conversion + + This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each + channel, because of an issue that is discussed in : + https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446 + + Args: + image (Image): + The image to convert. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. + """ + requires_backends(convert_to_grayscale, ["vision"]) + + if isinstance(image, np.ndarray): + if is_grayscale(image, input_data_format=input_data_format): + return image + if input_data_format == ChannelDimension.FIRST: + gray_image = image[0, ...] * 0.2989 + image[1, ...] * 0.5870 + image[2, ...] 
* 0.1140 + gray_image = np.stack([gray_image] * 3, axis=0) + elif input_data_format == ChannelDimension.LAST: + gray_image = image[..., 0] * 0.2989 + image[..., 1] * 0.5870 + image[..., 2] * 0.1140 + gray_image = np.stack([gray_image] * 3, axis=-1) + return gray_image + + if not isinstance(image, PIL.Image.Image): + return image + + image = image.convert("L") + return image + + +def validate_and_format_image_pairs(images: ImageInput): + error_message = ( + "Input images must be a one of the following :", + " - A pair of PIL images.", + " - A pair of 3D arrays.", + " - A list of pairs of PIL images.", + " - A list of pairs of 3D arrays.", + ) + + def _is_valid_image(image): + """images is a PIL Image or a 3D array.""" + return is_pil_image(image) or ( + is_valid_image(image) and get_image_type(image) != ImageType.PIL and len(image.shape) == 3 + ) + + if isinstance(images, list): + if len(images) == 2 and all((_is_valid_image(image)) for image in images): + return images + if all( + isinstance(image_pair, list) + and len(image_pair) == 2 + and all(_is_valid_image(image) for image in image_pair) + for image_pair in images + ): + return [image for image_pair in images for image in image_pair] + raise ValueError(error_message) + + +class SuperGlueImageProcessor(BaseImageProcessor): + r""" + Constructs a SuperGlue image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overriden + by `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"height": 480, "width": 640}`): + Resolution of the output image after `resize` is applied. Only has an effect if `do_resize` is set to + `True`. Can be overriden by `size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overriden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overriden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overriden by `rescale_factor` in the `preprocess` + method. + do_grayscale (`bool`, *optional*, defaults to `True`): + Whether to convert the image to grayscale. Can be overriden by `do_grayscale` in the `preprocess` method. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: float = 1 / 255, + do_grayscale: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 480, "width": 640} + size = get_size_dict(size, default_to_square=False) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_grayscale = do_grayscale + + # Copied from transformers.models.superpoint.image_processing_superpoint.SuperPointImageProcessor.resize + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ): + """ + Resize an image. 
+ + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Dictionary of the form `{"height": int, "width": int}`, specifying the size of the output image. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the output image. If not provided, it will be inferred from the input + image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + size = get_size_dict(size, default_to_square=False) + + return resize( + image, + size=(size["height"], size["width"]), + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_grayscale: bool = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image pairs to preprocess. Expects either a list of 2 images or a list of list of 2 images list with + pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set + `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the output image after `resize` has been applied. If `size["shortest_edge"]` >= 384, the image + is resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the + image will be matched to `int(size["shortest_edge"]/ crop_pct)`, after which the image is cropped to + `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of `PILImageResampling`, filters. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_grayscale (`bool`, *optional*, defaults to `self.do_grayscale`): + Whether to convert the image to grayscale. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. 
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_grayscale = do_grayscale if do_grayscale is not None else self.do_grayscale + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + + # Validate and convert the input images into a flattened list of images for all subsequent processing steps. + images = validate_and_format_image_pairs(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + ) + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + all_images = [] + for image in images: + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_grayscale: + image = convert_to_grayscale(image, input_data_format=input_data_format) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + all_images.append(image) + + # Convert back the flattened list of images into a list of pairs of images. 
+ image_pairs = [all_images[i : i + 2] for i in range(0, len(all_images), 2)] + + data = {"pixel_values": image_pairs} + + return BatchFeature(data=data, tensor_type=return_tensors) + + def post_process_keypoint_matching( + self, + outputs: "KeypointMatchingOutput", + target_sizes: Union[TensorType, List[Tuple]], + threshold: float = 0.0, + ) -> List[Dict[str, torch.Tensor]]: + """ + Converts the raw output of [`KeypointMatchingOutput`] into lists of keypoints, scores and descriptors + with coordinates absolute to the original image sizes. + Args: + outputs ([`KeypointMatchingOutput`]): + Raw outputs of the model. + target_sizes (`torch.Tensor` or `List[Tuple[Tuple[int, int]]]`, *optional*): + Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`Tuple[int, int]`) containing the + target size `(height, width)` of each image in the batch. This must be the original image size (before + any processing). + threshold (`float`, *optional*, defaults to 0.0): + Threshold to filter out the matches with low scores. + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the keypoints in the first and second image + of the pair, the matching scores and the matching indices. + """ + if outputs.mask.shape[0] != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask") + if not all(len(target_size) == 2 for target_size in target_sizes): + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + if isinstance(target_sizes, List): + image_pair_sizes = torch.tensor(target_sizes, device=outputs.mask.device) + else: + if target_sizes.shape[1] != 2 or target_sizes.shape[2] != 2: + raise ValueError( + "Each element of target_sizes must contain the size (h, w) of each image of the batch" + ) + image_pair_sizes = target_sizes + + keypoints = outputs.keypoints.clone() + keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2) + keypoints = keypoints.to(torch.int32) + + results = [] + for mask_pair, keypoints_pair, matches, scores in zip( + outputs.mask, keypoints, outputs.matches[:, 0], outputs.matching_scores[:, 0] + ): + mask0 = mask_pair[0] > 0 + mask1 = mask_pair[1] > 0 + keypoints0 = keypoints_pair[0][mask0] + keypoints1 = keypoints_pair[1][mask1] + matches0 = matches[mask0] + scores0 = scores[mask0] + + # Filter out matches with low scores + valid_matches = torch.logical_and(scores0 > threshold, matches0 > -1) + + matched_keypoints0 = keypoints0[valid_matches] + matched_keypoints1 = keypoints1[matches0[valid_matches]] + matching_scores = scores0[valid_matches] + + results.append( + { + "keypoints0": matched_keypoints0, + "keypoints1": matched_keypoints1, + "matching_scores": matching_scores, + } + ) + + return results + + +__all__ = ["SuperGlueImageProcessor"] diff --git a/src/transformers/models/superglue/modeling_superglue.py b/src/transformers/models/superglue/modeling_superglue.py new file mode 100644 index 000000000000..049eb91b8451 --- /dev/null +++ b/src/transformers/models/superglue/modeling_superglue.py @@ -0,0 +1,866 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch SuperGlue model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +from torch import nn + +from transformers import PreTrainedModel, add_start_docstrings +from transformers.models.superglue.configuration_superglue import SuperGlueConfig + +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging +from ..auto import AutoModelForKeypointDetection + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC_ = "SuperGlueConfig" +_CHECKPOINT_FOR_DOC_ = "magic-leap-community/superglue_indoor" + + +def concat_pairs(tensor_tuple0: Tuple[torch.Tensor], tensor_tuple1: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]: + """ + Concatenate two tuples of tensors pairwise + + Args: + tensor_tuple0 (`Tuple[torch.Tensor]`): + Tuple of tensors. + tensor_tuple1 (`Tuple[torch.Tensor]`): + Tuple of tensors. + + Returns: + (`Tuple[torch.Tensor]`): Tuple of concatenated tensors. + """ + return tuple([torch.cat([tensor0, tensor1]) for tensor0, tensor1 in zip(tensor_tuple0, tensor_tuple1)]) + + +def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + Normalize keypoints locations based on image image_shape + + Args: + keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`): + Keypoints locations in (x, y) format. + height (`int`): + Image height. + width (`int`): + Image width. + + Returns: + Normalized keypoints locations of shape (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`). + """ + size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None] + center = size / 2 + scaling = size.max(1, keepdim=True).values * 0.7 + return (keypoints - center[:, None, :]) / scaling[:, None, :] + + +def log_sinkhorn_iterations( + log_cost_matrix: torch.Tensor, + log_source_distribution: torch.Tensor, + log_target_distribution: torch.Tensor, + num_iterations: int, +) -> torch.Tensor: + """ + Perform Sinkhorn Normalization in Log-space for stability + + Args: + log_cost_matrix (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): + Logarithm of the cost matrix. + log_source_distribution (`torch.Tensor` of shape `(batch_size, num_rows)`): + Logarithm of the source distribution. + log_target_distribution (`torch.Tensor` of shape `(batch_size, num_columns)`): + Logarithm of the target distribution. + + Returns: + log_cost_matrix (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): Logarithm of the optimal + transport matrix. 
+ """ + log_u_scaling = torch.zeros_like(log_source_distribution) + log_v_scaling = torch.zeros_like(log_target_distribution) + for _ in range(num_iterations): + log_u_scaling = log_source_distribution - torch.logsumexp(log_cost_matrix + log_v_scaling.unsqueeze(1), dim=2) + log_v_scaling = log_target_distribution - torch.logsumexp(log_cost_matrix + log_u_scaling.unsqueeze(2), dim=1) + return log_cost_matrix + log_u_scaling.unsqueeze(2) + log_v_scaling.unsqueeze(1) + + +def log_optimal_transport(scores: torch.Tensor, reg_param: torch.Tensor, iterations: int) -> torch.Tensor: + """ + Perform Differentiable Optimal Transport in Log-space for stability + + Args: + scores: (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): + Cost matrix. + reg_param: (`torch.Tensor` of shape `(batch_size, 1, 1)`): + Regularization parameter. + iterations: (`int`): + Number of Sinkhorn iterations. + + Returns: + log_optimal_transport_matrix: (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): Logarithm of the + optimal transport matrix. + """ + batch_size, num_rows, num_columns = scores.shape + one_tensor = scores.new_tensor(1) + num_rows_tensor, num_columns_tensor = (num_rows * one_tensor).to(scores), (num_columns * one_tensor).to(scores) + + source_reg_param = reg_param.expand(batch_size, num_rows, 1) + target_reg_param = reg_param.expand(batch_size, 1, num_columns) + reg_param = reg_param.expand(batch_size, 1, 1) + + couplings = torch.cat([torch.cat([scores, source_reg_param], -1), torch.cat([target_reg_param, reg_param], -1)], 1) + + log_normalization = -(num_rows_tensor + num_columns_tensor).log() + log_source_distribution = torch.cat( + [log_normalization.expand(num_rows), num_columns_tensor.log()[None] + log_normalization] + ) + log_target_distribution = torch.cat( + [log_normalization.expand(num_columns), num_rows_tensor.log()[None] + log_normalization] + ) + log_source_distribution, log_target_distribution = ( + log_source_distribution[None].expand(batch_size, -1), + log_target_distribution[None].expand(batch_size, -1), + ) + + log_optimal_transport_matrix = log_sinkhorn_iterations( + couplings, log_source_distribution, log_target_distribution, num_iterations=iterations + ) + log_optimal_transport_matrix = log_optimal_transport_matrix - log_normalization # multiply probabilities by M+N + return log_optimal_transport_matrix + + +def arange_like(x, dim: int) -> torch.Tensor: + return x.new_ones(x.shape[dim]).cumsum(0) - 1 + + +@dataclass +class KeypointMatchingOutput(ModelOutput): + """ + Base class for outputs of keypoint matching models. Due to the nature of keypoint detection and matching, the number + of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of + images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask tensor is + used to indicate which values in the keypoints, matches and matching_scores tensors are keypoint matching + information. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*): + Loss computed during training. + mask (`torch.IntTensor` of shape `(batch_size, num_keypoints)`): + Mask indicating which values in matches and matching_scores are keypoint matching information. + matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): + Index of keypoint matched in the other image. + matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): + Scores of predicted matches. 
+ keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`): + Absolute (x, y) coordinates of predicted keypoints in a given image. + hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels, + num_keypoints)`, returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`) + attentions (`Tuple[torch.FloatTensor, ...]`, *optional*): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints, + num_keypoints)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) + """ + + loss: Optional[torch.FloatTensor] = None + matches: Optional[torch.FloatTensor] = None + matching_scores: Optional[torch.FloatTensor] = None + keypoints: Optional[torch.FloatTensor] = None + mask: Optional[torch.IntTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class SuperGlueMultiLayerPerceptron(nn.Module): + def __init__(self, config: SuperGlueConfig, in_channels: int, out_channels: int) -> None: + super().__init__() + self.linear = nn.Linear(in_channels, out_channels) + self.batch_norm = nn.BatchNorm1d(out_channels) + self.activation = nn.ReLU() + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.linear(hidden_state) + hidden_state = hidden_state.transpose(-1, -2) + hidden_state = self.batch_norm(hidden_state) + hidden_state = hidden_state.transpose(-1, -2) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class SuperGlueKeypointEncoder(nn.Module): + def __init__(self, config: SuperGlueConfig) -> None: + super().__init__() + layer_sizes = config.keypoint_encoder_sizes + hidden_size = config.hidden_size + # 3 here consists of 2 for the (x, y) coordinates and 1 for the score of the keypoint + encoder_channels = [3] + layer_sizes + [hidden_size] + + layers = [ + SuperGlueMultiLayerPerceptron(config, encoder_channels[i - 1], encoder_channels[i]) + for i in range(1, len(encoder_channels) - 1) + ] + layers.append(nn.Linear(encoder_channels[-2], encoder_channels[-1])) + self.encoder = nn.ModuleList(layers) + + def forward( + self, + keypoints: torch.Tensor, + scores: torch.Tensor, + output_hidden_states: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: + scores = scores.unsqueeze(2) + hidden_state = torch.cat([keypoints, scores], dim=2) + all_hidden_states = () if output_hidden_states else None + for layer in self.encoder: + hidden_state = layer(hidden_state) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + return hidden_state, all_hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->SuperGlue +class SuperGlueSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, 
self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in SuperGlueModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class SuperGlueSelfOutput(nn.Module): + def __init__(self, config: SuperGlueConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor, *args) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + return hidden_states + + +SUPERGLUE_SELF_ATTENTION_CLASSES = { + "eager": SuperGlueSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->SuperGlue,BERT->SUPERGLUE +class SuperGlueAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = SUPERGLUE_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) + self.output = SuperGlueSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class SuperGlueAttentionalPropagation(nn.Module): + def __init__(self, config: SuperGlueConfig) -> None: + super().__init__() + hidden_size = config.hidden_size + self.attention = SuperGlueAttention(config) + mlp_channels = [hidden_size * 2, hidden_size * 2, hidden_size] + layers = [ + SuperGlueMultiLayerPerceptron(config, mlp_channels[i - 1], mlp_channels[i]) + for i in range(1, len(mlp_channels) - 1) + ] + layers.append(nn.Linear(mlp_channels[-2], mlp_channels[-1])) + self.mlp = nn.ModuleList(layers) + + def forward( + self, + descriptors: torch.Tensor, + attention_mask: 
Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]], Optional[Tuple[torch.Tensor]]]: + attention_outputs = self.attention( + descriptors, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + output = attention_outputs[0] + attention = attention_outputs[1:] + + hidden_state = torch.cat([descriptors, output], dim=2) + + all_hidden_states = () if output_hidden_states else None + for layer in self.mlp: + hidden_state = layer(hidden_state) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + + return hidden_state, all_hidden_states, attention + + +class SuperGlueAttentionalGNN(nn.Module): + def __init__(self, config: SuperGlueConfig) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.layers_types = config.gnn_layers_types + self.layers = nn.ModuleList([SuperGlueAttentionalPropagation(config) for _ in range(len(self.layers_types))]) + + def forward( + self, + descriptors: torch.Tensor, + mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[Tuple], Optional[Tuple]]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + batch_size, num_keypoints, _ = descriptors.shape + if output_hidden_states: + all_hidden_states = all_hidden_states + (descriptors,) + + for gnn_layer, layer_type in zip(self.layers, self.layers_types): + encoder_hidden_states = None + encoder_attention_mask = None + if layer_type == "cross": + encoder_hidden_states = ( + descriptors.reshape(-1, 2, num_keypoints, self.hidden_size) + .flip(1) + .reshape(batch_size, num_keypoints, self.hidden_size) + ) + encoder_attention_mask = ( + mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints) + if mask is not None + else None + ) + + gnn_outputs = gnn_layer( + descriptors, + attention_mask=mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + delta = gnn_outputs[0] + + if output_hidden_states: + all_hidden_states = all_hidden_states + gnn_outputs[1] + if output_attentions: + all_attentions = all_attentions + gnn_outputs[2] + + descriptors = descriptors + delta + return descriptors, all_hidden_states, all_attentions + + +class SuperGlueFinalProjection(nn.Module): + def __init__(self, config: SuperGlueConfig) -> None: + super().__init__() + hidden_size = config.hidden_size + self.final_proj = nn.Linear(hidden_size, hidden_size, bias=True) + + def forward(self, descriptors: torch.Tensor) -> torch.Tensor: + return self.final_proj(descriptors) + + +class SuperGluePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = SuperGlueConfig + base_model_prefix = "superglue" + main_input_name = "pixel_values" + + def _init_weights(self, module: nn.Module) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, SuperGlueMultiLayerPerceptron): + nn.init.constant_(module.linear.bias, 0.0) + + +SUPERGLUE_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`SuperGlueConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. + """ + +SUPERGLUE_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`SuperGlueImageProcessor`]. See + [`SuperGlueImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors. See `attentions` under returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "SuperGlue model taking images as inputs and outputting the matching of them.", + SUPERGLUE_START_DOCSTRING, +) +class SuperGlueForKeypointMatching(SuperGluePreTrainedModel): + """SuperGlue feature matching middle-end + + Given two sets of keypoints and locations, we determine the + correspondences by: + 1. Keypoint Encoding (normalization + visual feature and location fusion) + 2. Graph Neural Network with multiple self and cross-attention layers + 3. Final projection layer + 4. Optimal Transport Layer (a differentiable Hungarian matching algorithm) + 5. Thresholding matrix based on mutual exclusivity and a match_threshold + + The correspondence ids use -1 to indicate non-matching points. + + Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew + Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural + Networks. In CVPR, 2020. 
https://arxiv.org/abs/1911.11763
+ """
+
+ def __init__(self, config: SuperGlueConfig) -> None:
+ super().__init__(config)
+
+ self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config)
+
+ self.keypoint_encoder = SuperGlueKeypointEncoder(config)
+ self.gnn = SuperGlueAttentionalGNN(config)
+ self.final_projection = SuperGlueFinalProjection(config)
+
+ bin_score = torch.nn.Parameter(torch.tensor(1.0))
+ self.register_parameter("bin_score", bin_score)
+
+ self.post_init()
+
+ def _match_image_pair(
+ self,
+ keypoints: torch.Tensor,
+ descriptors: torch.Tensor,
+ scores: torch.Tensor,
+ height: int,
+ width: int,
+ mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Tuple, Tuple]:
+ """
+ Perform keypoint matching between two images.
+
+ Args:
+ keypoints (`torch.Tensor` of shape `(batch_size, 2, num_keypoints, 2)`):
+ Keypoints detected in the image pair.
+ descriptors (`torch.Tensor` of shape `(batch_size, 2, descriptor_dim, num_keypoints)`):
+ Descriptors of the keypoints detected in the image pair.
+ scores (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
+ Confidence scores of the keypoints detected in the image pair.
+ height (`int`): Image height.
+ width (`int`): Image width.
+ mask (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`, *optional*):
+ Mask indicating which values in the keypoints, matches and matching_scores tensors are keypoint matching
+ information.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors. Defaults to `config.output_attentions`.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. Defaults to `config.output_hidden_states`.
+
+ Returns:
+ matches (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
+ For each keypoint in image0, the index of the matching keypoint in image1, and for each keypoint in
+ image1, the index of the matching keypoint in image0.
+ matching_scores (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
+ Scores of the predicted matches for each image pair.
+ all_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+ Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(1, 2, num_keypoints,
+ num_channels)`.
+ all_attentions (`tuple(torch.FloatTensor)`, *optional*):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(1, 2, num_heads, num_keypoints,
+ num_keypoints)`.
+ """ + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if keypoints.shape[2] == 0: # no keypoints + shape = keypoints.shape[:-1] + return ( + keypoints.new_full(shape, -1, dtype=torch.int), + keypoints.new_zeros(shape), + all_hidden_states, + all_attentions, + ) + + batch_size, _, num_keypoints, _ = keypoints.shape + # (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2) + keypoints = keypoints.reshape(batch_size * 2, num_keypoints, 2) + descriptors = descriptors.reshape(batch_size * 2, num_keypoints, self.config.hidden_size) + scores = scores.reshape(batch_size * 2, num_keypoints) + mask = mask.reshape(batch_size * 2, num_keypoints) if mask is not None else None + + # Keypoint normalization + keypoints = normalize_keypoints(keypoints, height, width) + + encoded_keypoints = self.keypoint_encoder(keypoints, scores, output_hidden_states=output_hidden_states) + + last_hidden_state = encoded_keypoints[0] + + # Keypoint MLP encoder. + descriptors = descriptors + last_hidden_state + + if mask is not None: + input_shape = descriptors.size() + extended_attention_mask = self.get_extended_attention_mask(mask, input_shape) + else: + extended_attention_mask = torch.ones((batch_size, num_keypoints), device=keypoints.device) + + # Multi-layer Transformer network. + gnn_outputs = self.gnn( + descriptors, + mask=extended_attention_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + descriptors = gnn_outputs[0] + + # Final MLP projection. + projected_descriptors = self.final_projection(descriptors) + + # (batch_size * 2, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim) + final_descriptors = projected_descriptors.reshape(batch_size, 2, num_keypoints, self.config.hidden_size) + final_descriptors0 = final_descriptors[:, 0] + final_descriptors1 = final_descriptors[:, 1] + + # Compute matching descriptor distance. + scores = final_descriptors0 @ final_descriptors1.transpose(1, 2) + scores = scores / self.config.hidden_size**0.5 + + if mask is not None: + mask = mask.reshape(batch_size, 2, num_keypoints) + mask0 = mask[:, 0].unsqueeze(-1).expand(-1, -1, num_keypoints) + scores = scores.masked_fill(mask0 == 0, -1e9) + + # Run the optimal transport. + scores = log_optimal_transport(scores, self.bin_score, iterations=self.config.sinkhorn_iterations) + + # Get the matches with score above "match_threshold". 
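+ # A pair of keypoints is kept only if the two points are each other's best candidate in the
+ # score matrix (mutual nearest neighbors, ignoring the dustbin row/column added by the optimal
+ # transport step) and the exponentiated score is above `config.matching_threshold`.
+ # Keypoints left unmatched are assigned the index -1.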
+ max0 = scores[:, :-1, :-1].max(2)
+ max1 = scores[:, :-1, :-1].max(1)
+ indices0 = max0.indices
+ indices1 = max1.indices
+ mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
+ mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
+ zero = scores.new_tensor(0)
+ matching_scores0 = torch.where(mutual0, max0.values.exp(), zero)
+ matching_scores0 = torch.where(matching_scores0 > self.config.matching_threshold, matching_scores0, zero)
+ matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, indices1), zero)
+ valid0 = mutual0 & (matching_scores0 > zero)
+ valid1 = mutual1 & valid0.gather(1, indices1)
+ matches0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
+ matches1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
+
+ matches = torch.cat([matches0, matches1]).reshape(batch_size, 2, -1)
+ matching_scores = torch.cat([matching_scores0, matching_scores1]).reshape(batch_size, 2, -1)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + encoded_keypoints[1]
+ all_hidden_states = all_hidden_states + gnn_outputs[1]
+ all_hidden_states = all_hidden_states + (projected_descriptors,)
+ all_hidden_states = tuple(
+ x.reshape(batch_size, 2, num_keypoints, -1).transpose(-1, -2) for x in all_hidden_states
+ )
+ if output_attentions:
+ all_attentions = all_attentions + gnn_outputs[2]
+ all_attentions = tuple(x.reshape(batch_size, 2, -1, num_keypoints, num_keypoints) for x in all_attentions)
+
+ return (
+ matches,
+ matching_scores,
+ all_hidden_states,
+ all_attentions,
+ )
+
+ @add_start_docstrings_to_model_forward(SUPERGLUE_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, KeypointMatchingOutput]:
+ """
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, AutoModel
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true"
+ >>> image1 = Image.open(requests.get(url, stream=True).raw)
+ >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true"
+ >>> image2 = Image.open(requests.get(url, stream=True).raw)
+ >>> images = [image1, image2]
+
+ >>> processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
+ >>> model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")
+
+ >>> inputs = processor(images, return_tensors="pt")
+ >>> with torch.no_grad():
+ ...     outputs = model(**inputs)
+ ```"""
+ loss = None
+ if labels is not None:
+ raise ValueError("SuperGlue is not trainable, no labels should be provided.")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values.ndim != 5 or pixel_values.size(1) != 2:
+ raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")
+
+ batch_size, _, channels, height, width = pixel_values.shape
+
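+ # The two images of each pair are flattened into the batch dimension so that the keypoint
+ # detector processes every image at once; its outputs are reshaped back into per-pair tensors.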
pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width) + keypoint_detections = self.keypoint_detector(pixel_values) + + keypoints, scores, descriptors, mask = keypoint_detections[:4] + keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values) + scores = scores.reshape(batch_size, 2, -1).to(pixel_values) + descriptors = descriptors.reshape(batch_size, 2, -1, self.config.hidden_size).to(pixel_values) + mask = mask.reshape(batch_size, 2, -1) + + absolute_keypoints = keypoints.clone() + absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width + absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height + + matches, matching_scores, hidden_states, attentions = self._match_image_pair( + absolute_keypoints, + descriptors, + scores, + height, + width, + mask=mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if not return_dict: + return tuple( + v + for v in [loss, matches, matching_scores, keypoints, mask, hidden_states, attentions] + if v is not None + ) + + return KeypointMatchingOutput( + loss=loss, + matches=matches, + matching_scores=matching_scores, + keypoints=keypoints, + mask=mask, + hidden_states=hidden_states, + attentions=attentions, + ) + + +__all__ = ["SuperGluePreTrainedModel", "SuperGlueForKeypointMatching"] diff --git a/src/transformers/models/superpoint/convert_superpoint_to_pytorch.py b/src/transformers/models/superpoint/convert_superpoint_to_pytorch.py index 18755bf4fe01..007966a0557a 100644 --- a/src/transformers/models/superpoint/convert_superpoint_to_pytorch.py +++ b/src/transformers/models/superpoint/convert_superpoint_to_pytorch.py @@ -144,7 +144,7 @@ def convert_superpoint_checkpoint(checkpoint_url, pytorch_dump_folder_path, save model.save_pretrained(pytorch_dump_folder_path) preprocessor.save_pretrained(pytorch_dump_folder_path) - model_name = "superpoint" + model_name = "magic-leap-community/superpoint" if push_to_hub: print(f"Pushing {model_name} to the hub...") model.push_to_hub(model_name) diff --git a/src/transformers/models/superpoint/image_processing_superpoint.py b/src/transformers/models/superpoint/image_processing_superpoint.py index 65309b1c1826..6fac35ecd1af 100644 --- a/src/transformers/models/superpoint/image_processing_superpoint.py +++ b/src/transformers/models/superpoint/image_processing_superpoint.py @@ -49,8 +49,12 @@ def is_grayscale( input_data_format: Optional[Union[str, ChannelDimension]] = None, ): if input_data_format == ChannelDimension.FIRST: + if image.shape[0] == 1: + return True return np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...]) elif input_data_format == ChannelDimension.LAST: + if image.shape[-1] == 1: + return True return np.all(image[..., 0] == image[..., 1]) and np.all(image[..., 1] == image[..., 2]) @@ -75,6 +79,8 @@ def convert_to_grayscale( requires_backends(convert_to_grayscale, ["vision"]) if isinstance(image, np.ndarray): + if is_grayscale(image, input_data_format=input_data_format): + return image if input_data_format == ChannelDimension.FIRST: gray_image = image[0, ...] * 0.2989 + image[1, ...] * 0.5870 + image[2, ...] * 0.1140 gray_image = np.stack([gray_image] * 3, axis=0) @@ -107,6 +113,8 @@ class SuperPointImageProcessor(BaseImageProcessor): rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): Scale factor to use if rescaling the image. Can be overriden by `rescale_factor` in the `preprocess` method. 
+ do_grayscale (`bool`, *optional*, defaults to `False`): + Whether to convert the image to grayscale. Can be overriden by `do_grayscale` in the `preprocess` method. """ model_input_names = ["pixel_values"] @@ -117,6 +125,7 @@ def __init__( size: Dict[str, int] = None, do_rescale: bool = True, rescale_factor: float = 1 / 255, + do_grayscale: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) @@ -127,6 +136,7 @@ def __init__( self.size = size self.do_rescale = do_rescale self.rescale_factor = rescale_factor + self.do_grayscale = do_grayscale def resize( self, @@ -174,6 +184,7 @@ def preprocess( size: Dict[str, int] = None, do_rescale: bool = None, rescale_factor: float = None, + do_grayscale: bool = None, return_tensors: Optional[Union[str, TensorType]] = None, data_format: ChannelDimension = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, @@ -197,6 +208,8 @@ def preprocess( Whether to rescale the image values between [0 - 1]. rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_grayscale (`bool`, *optional*, defaults to `self.do_grayscale`): + Whether to convert the image to grayscale. return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. @@ -220,6 +233,7 @@ def preprocess( do_resize = do_resize if do_resize is not None else self.do_resize do_rescale = do_rescale if do_rescale is not None else self.do_rescale rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_grayscale = do_grayscale if do_grayscale is not None else self.do_grayscale size = size if size is not None else self.size size = get_size_dict(size, default_to_square=False) @@ -264,10 +278,8 @@ def preprocess( # We assume that all images have the same channel dimension format. 
input_data_format = infer_channel_dimension_format(images[0]) - # Checking if image is RGB or grayscale - for i in range(len(images)): - if not is_grayscale(images[i], input_data_format): - images[i] = convert_to_grayscale(images[i], input_data_format=input_data_format) + if do_grayscale: + images = [convert_to_grayscale(image, input_data_format=input_data_format) for image in images] images = [ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images @@ -299,7 +311,7 @@ def post_process_keypoint_detection( raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask") if isinstance(target_sizes, List): - image_sizes = torch.tensor(target_sizes) + image_sizes = torch.tensor(target_sizes, device=outputs.mask.device) else: if target_sizes.shape[1] != 2: raise ValueError( diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 36e1ff2cfe65..731ba79c36f4 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -8622,6 +8622,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class SuperGlueForKeypointMatching(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class SuperGluePreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class SuperPointForKeypointDetection(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 19cf02a4e858..d85d4c92cf57 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -611,6 +611,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class SuperGlueImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class SuperPointImageProcessor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/models/superglue/__init__.py b/tests/models/superglue/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/superglue/test_image_processing_superglue.py b/tests/models/superglue/test_image_processing_superglue.py new file mode 100644 index 000000000000..b98d34888cfc --- /dev/null +++ b/tests/models/superglue/test_image_processing_superglue.py @@ -0,0 +1,384 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import unittest + +from parameterized import parameterized + +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_image_processing_common import ( + ImageProcessingTestMixin, + prepare_image_inputs, +) + + +if is_torch_available(): + import numpy as np + import torch + + from transformers.models.superglue.modeling_superglue import KeypointMatchingOutput + +if is_vision_available(): + from transformers import SuperGlueImageProcessor + + +def random_array(size): + return np.random.randint(255, size=size) + + +def random_tensor(size): + return torch.rand(size) + + +class SuperGlueImageProcessingTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=6, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + do_resize=True, + size=None, + do_grayscale=True, + ): + size = size if size is not None else {"height": 480, "width": 640} + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = size + self.do_grayscale = do_grayscale + + def prepare_image_processor_dict(self): + return { + "do_resize": self.do_resize, + "size": self.size, + "do_grayscale": self.do_grayscale, + } + + def expected_output_image_shape(self, images): + return 2, self.num_channels, self.size["height"], self.size["width"] + + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False, pairs=True, batch_size=None): + batch_size = batch_size if batch_size is not None else self.batch_size + image_inputs = prepare_image_inputs( + batch_size=batch_size, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + if pairs: + image_inputs = [image_inputs[i : i + 2] for i in range(0, len(image_inputs), 2)] + return image_inputs + + def prepare_keypoint_matching_output(self, pixel_values): + max_number_keypoints = 50 + batch_size = len(pixel_values) + mask = torch.zeros((batch_size, 2, max_number_keypoints), dtype=torch.int) + keypoints = torch.zeros((batch_size, 2, max_number_keypoints, 2)) + matches = torch.full((batch_size, 2, max_number_keypoints), -1, dtype=torch.int) + scores = torch.zeros((batch_size, 2, max_number_keypoints)) + for i in range(batch_size): + random_number_keypoints0 = np.random.randint(10, max_number_keypoints) + random_number_keypoints1 = np.random.randint(10, max_number_keypoints) + random_number_matches = np.random.randint(5, min(random_number_keypoints0, random_number_keypoints1)) + mask[i, 0, :random_number_keypoints0] = 1 + mask[i, 1, :random_number_keypoints1] = 1 + keypoints[i, 0, :random_number_keypoints0] = torch.rand((random_number_keypoints0, 2)) + keypoints[i, 1, :random_number_keypoints1] = torch.rand((random_number_keypoints1, 2)) + random_matches_indices0 = torch.randperm(random_number_keypoints1, dtype=torch.int)[:random_number_matches] + random_matches_indices1 = torch.randperm(random_number_keypoints0, dtype=torch.int)[:random_number_matches] + matches[i, 0, random_matches_indices1] = random_matches_indices0 + matches[i, 1, random_matches_indices0] = random_matches_indices1 + scores[i, 0, random_matches_indices1] = torch.rand((random_number_matches,)) + scores[i, 1, random_matches_indices0] = 
torch.rand((random_number_matches,)) + return KeypointMatchingOutput(mask=mask, keypoints=keypoints, matches=matches, matching_scores=scores) + + +@require_torch +@require_vision +class SuperGlueImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + image_processing_class = SuperGlueImageProcessor if is_vision_available() else None + + def setUp(self) -> None: + super().setUp() + self.image_processor_tester = SuperGlueImageProcessingTester(self) + + @property + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processing(self): + image_processing = self.image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_rescale")) + self.assertTrue(hasattr(image_processing, "rescale_factor")) + self.assertTrue(hasattr(image_processing, "do_grayscale")) + + def test_image_processor_from_dict_with_kwargs(self): + image_processor = self.image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"height": 480, "width": 640}) + + image_processor = self.image_processing_class.from_dict( + self.image_processor_dict, size={"height": 42, "width": 42} + ) + self.assertEqual(image_processor.size, {"height": 42, "width": 42}) + + @unittest.skip(reason="SuperPointImageProcessor is always supposed to return a grayscaled image") + def test_call_numpy_4_channels(self): + pass + + def test_number_and_format_of_images_in_input(self): + image_processor = self.image_processing_class.from_dict(self.image_processor_dict) + + # Cases where the number of images and the format of lists in the input is correct + image_input = self.image_processor_tester.prepare_image_inputs(pairs=False, batch_size=2) + image_processed = image_processor.preprocess(image_input, return_tensors="pt") + self.assertEqual((1, 2, 3, 480, 640), tuple(image_processed["pixel_values"].shape)) + + image_input = self.image_processor_tester.prepare_image_inputs(pairs=True, batch_size=2) + image_processed = image_processor.preprocess(image_input, return_tensors="pt") + self.assertEqual((1, 2, 3, 480, 640), tuple(image_processed["pixel_values"].shape)) + + image_input = self.image_processor_tester.prepare_image_inputs(pairs=True, batch_size=4) + image_processed = image_processor.preprocess(image_input, return_tensors="pt") + self.assertEqual((2, 2, 3, 480, 640), tuple(image_processed["pixel_values"].shape)) + + image_input = self.image_processor_tester.prepare_image_inputs(pairs=True, batch_size=6) + image_processed = image_processor.preprocess(image_input, return_tensors="pt") + self.assertEqual((3, 2, 3, 480, 640), tuple(image_processed["pixel_values"].shape)) + + # Cases where the number of images or the format of lists in the input is incorrect + ## List of 4 images + image_input = self.image_processor_tester.prepare_image_inputs(pairs=False, batch_size=4) + with self.assertRaises(ValueError) as cm: + image_processor.preprocess(image_input, return_tensors="pt") + self.assertEqual(ValueError, cm.exception.__class__) + + ## List of 3 images + image_input = self.image_processor_tester.prepare_image_inputs(pairs=False, batch_size=3) + with self.assertRaises(ValueError) as cm: + image_processor.preprocess(image_input, return_tensors="pt") + self.assertEqual(ValueError, cm.exception.__class__) + + ## List of 2 pairs and 1 image + image_input = 
self.image_processor_tester.prepare_image_inputs(pairs=True, batch_size=3) + with self.assertRaises(ValueError) as cm: + image_processor.preprocess(image_input, return_tensors="pt") + self.assertEqual(ValueError, cm.exception.__class__) + + @parameterized.expand( + [ + ([random_array((3, 100, 200)), random_array((3, 100, 200))], (1, 2, 3, 480, 640)), + ([[random_array((3, 100, 200)), random_array((3, 100, 200))]], (1, 2, 3, 480, 640)), + ([random_tensor((3, 100, 200)), random_tensor((3, 100, 200))], (1, 2, 3, 480, 640)), + ([random_tensor((3, 100, 200)), random_tensor((3, 100, 200))], (1, 2, 3, 480, 640)), + ], + ) + def test_valid_image_shape_in_input(self, image_input, output): + image_processor = self.image_processing_class.from_dict(self.image_processor_dict) + image_processed = image_processor.preprocess(image_input, return_tensors="pt") + self.assertEqual(output, tuple(image_processed["pixel_values"].shape)) + + @parameterized.expand( + [ + (random_array((3, 100, 200)),), + ([random_array((3, 100, 200))],), + (random_array((1, 3, 100, 200)),), + ([[random_array((3, 100, 200))]],), + ([[random_array((3, 100, 200))], [random_array((3, 100, 200))]],), + ([random_array((1, 3, 100, 200)), random_array((1, 3, 100, 200))],), + (random_array((1, 1, 3, 100, 200)),), + ], + ) + def test_invalid_image_shape_in_input(self, image_input): + image_processor = self.image_processing_class.from_dict(self.image_processor_dict) + with self.assertRaises(ValueError) as cm: + image_processor.preprocess(image_input, return_tensors="pt") + self.assertEqual(ValueError, cm.exception.__class__) + + def test_input_images_properly_paired(self): + image_processor = self.image_processing_class.from_dict(self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs() + pre_processed_images = image_processor.preprocess(image_inputs, return_tensors="np") + self.assertEqual(len(pre_processed_images["pixel_values"].shape), 5) + self.assertEqual(pre_processed_images["pixel_values"].shape[1], 2) + + def test_input_not_paired_images_raises_error(self): + image_processor = self.image_processing_class.from_dict(self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(pairs=False) + with self.assertRaises(ValueError): + image_processor.preprocess(image_inputs[0]) + + def test_input_image_properly_converted_to_grayscale(self): + image_processor = self.image_processing_class.from_dict(self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs() + pre_processed_images = image_processor.preprocess(image_inputs) + for image_pair in pre_processed_images["pixel_values"]: + for image in image_pair: + self.assertTrue(np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] 
== image[2, ...])) + + def test_call_numpy(self): + # Test overwritten because SuperGlueImageProcessor combines images by pair to feed it into SuperGlue + + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random numpy tensors + image_pairs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) + for image_pair in image_pairs: + self.assertEqual(len(image_pair), 2) + + expected_batch_size = int(self.image_processor_tester.batch_size / 2) + + # Test with 2 images + encoded_images = image_processing(image_pairs[0], return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs[0]) + self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) + + # Test with list of pairs + encoded_images = image_processing(image_pairs, return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs) + self.assertEqual(tuple(encoded_images.shape), (expected_batch_size, *expected_output_image_shape)) + + # Test without paired images + image_pairs = self.image_processor_tester.prepare_image_inputs( + equal_resolution=False, numpify=True, pairs=False + ) + with self.assertRaises(ValueError): + image_processing(image_pairs, return_tensors="pt").pixel_values + + def test_call_pil(self): + # Test overwritten because SuperGlueImageProcessor combines images by pair to feed it into SuperGlue + + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random PIL images + image_pairs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) + for image_pair in image_pairs: + self.assertEqual(len(image_pair), 2) + + expected_batch_size = int(self.image_processor_tester.batch_size / 2) + + # Test with 2 images + encoded_images = image_processing(image_pairs[0], return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs[0]) + self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) + + # Test with list of pairs + encoded_images = image_processing(image_pairs, return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs) + self.assertEqual(tuple(encoded_images.shape), (expected_batch_size, *expected_output_image_shape)) + + # Test without paired images + image_pairs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, pairs=False) + with self.assertRaises(ValueError): + image_processing(image_pairs, return_tensors="pt").pixel_values + + def test_call_pytorch(self): + # Test overwritten because SuperGlueImageProcessor combines images by pair to feed it into SuperGlue + + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_pairs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + for image_pair in image_pairs: + self.assertEqual(len(image_pair), 2) + + expected_batch_size = int(self.image_processor_tester.batch_size / 2) + + # Test with 2 images + encoded_images = image_processing(image_pairs[0], return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs[0]) + 
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
+ # Test with list of pairs
+ encoded_images = image_processing(image_pairs, return_tensors="pt").pixel_values
+ expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs)
+ self.assertEqual(tuple(encoded_images.shape), (expected_batch_size, *expected_output_image_shape))
+
+ # Test without paired images
+ image_pairs = self.image_processor_tester.prepare_image_inputs(
+ equal_resolution=False, torchify=True, pairs=False
+ )
+ with self.assertRaises(ValueError):
+ image_processing(image_pairs, return_tensors="pt").pixel_values
+
+ def test_image_processor_with_list_of_two_images(self):
+ image_processing = self.image_processing_class(**self.image_processor_dict)
+
+ image_pairs = self.image_processor_tester.prepare_image_inputs(
+ equal_resolution=False, numpify=True, batch_size=2, pairs=False
+ )
+ self.assertEqual(len(image_pairs), 2)
+ self.assertTrue(isinstance(image_pairs[0], np.ndarray))
+ self.assertTrue(isinstance(image_pairs[1], np.ndarray))
+
+ expected_batch_size = 1
+ encoded_images = image_processing(image_pairs, return_tensors="pt").pixel_values
+ expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_pairs[0])
+ self.assertEqual(tuple(encoded_images.shape), (expected_batch_size, *expected_output_image_shape))
+
+ @require_torch
+ def test_post_processing_keypoint_matching(self):
+ image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
+ image_inputs = self.image_processor_tester.prepare_image_inputs()
+ pre_processed_images = image_processor.preprocess(image_inputs, return_tensors="pt")
+ outputs = self.image_processor_tester.prepare_keypoint_matching_output(**pre_processed_images)
+
+ def check_post_processed_output(post_processed_output, image_pair_size):
+ for post_processed_output, (image_size0, image_size1) in zip(post_processed_output, image_pair_size):
+ self.assertTrue("keypoints0" in post_processed_output)
+ self.assertTrue("keypoints1" in post_processed_output)
+ self.assertTrue("matching_scores" in post_processed_output)
+ keypoints0 = post_processed_output["keypoints0"]
+ keypoints1 = post_processed_output["keypoints1"]
+ all_below_image_size0 = torch.all(keypoints0[:, 0] <= image_size0[1]) and torch.all(
+ keypoints0[:, 1] <= image_size0[0]
+ )
+ all_below_image_size1 = torch.all(keypoints1[:, 0] <= image_size1[1]) and torch.all(
+ keypoints1[:, 1] <= image_size1[0]
+ )
+ all_above_zero0 = torch.all(keypoints0[:, 0] >= 0) and torch.all(keypoints0[:, 1] >= 0)
+ all_above_zero1 = torch.all(keypoints1[:, 0] >= 0) and torch.all(keypoints1[:, 1] >= 0)
+ self.assertTrue(all_below_image_size0)
+ self.assertTrue(all_below_image_size1)
+ self.assertTrue(all_above_zero0)
+ self.assertTrue(all_above_zero1)
+ all_scores_different_from_minus_one = torch.all(post_processed_output["matching_scores"] != -1)
+ self.assertTrue(all_scores_different_from_minus_one)
+
+ tuple_image_sizes = [
+ ((image_pair[0].size[0], image_pair[0].size[1]), (image_pair[1].size[0], image_pair[1].size[1]))
+ for image_pair in image_inputs
+ ]
+ tuple_post_processed_outputs = image_processor.post_process_keypoint_matching(outputs, tuple_image_sizes)
+
+ check_post_processed_output(tuple_post_processed_outputs, tuple_image_sizes)
+
+ tensor_image_sizes = torch.tensor(
+ [(image_pair[0].size, image_pair[1].size) for image_pair in image_inputs]
+ ).flip(2)
+ tensor_post_processed_outputs =
image_processor.post_process_keypoint_matching(outputs, tensor_image_sizes) + + check_post_processed_output(tensor_post_processed_outputs, tensor_image_sizes) diff --git a/tests/models/superglue/test_modeling_superglue.py b/tests/models/superglue/test_modeling_superglue.py new file mode 100644 index 000000000000..0dda82ed8ad9 --- /dev/null +++ b/tests/models/superglue/test_modeling_superglue.py @@ -0,0 +1,427 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import unittest +from typing import List + +from datasets import load_dataset + +from transformers.models.superglue.configuration_superglue import SuperGlueConfig +from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.utils import cached_property, is_torch_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor + + +if is_torch_available(): + import torch + + from transformers import SuperGlueForKeypointMatching + +if is_vision_available(): + from transformers import AutoImageProcessor + + +class SuperGlueModelTester: + def __init__( + self, + parent, + batch_size=2, + image_width=80, + image_height=60, + keypoint_detector_config=None, + hidden_size: int = 64, + keypoint_encoder_sizes: List[int] = [32, 64], + gnn_layers_types: List[str] = ["self", "cross"] * 2, + num_attention_heads: int = 4, + sinkhorn_iterations: int = 100, + matching_threshold: float = 0.2, + ): + if keypoint_detector_config is None: + keypoint_detector_config = { + "encoder_hidden_sizes": [32, 64], + "decoder_hidden_size": 64, + "keypoint_decoder_dim": 65, + "descriptor_decoder_dim": 64, + "keypoint_threshold": 0.005, + "max_keypoints": 256, + "nms_radius": 4, + "border_removal_distance": 4, + } + self.parent = parent + self.batch_size = batch_size + self.image_width = image_width + self.image_height = image_height + + self.keypoint_detector_config = keypoint_detector_config + self.hidden_size = hidden_size + self.keypoint_encoder_sizes = keypoint_encoder_sizes + self.gnn_layers_types = gnn_layers_types + self.num_attention_heads = num_attention_heads + self.sinkhorn_iterations = sinkhorn_iterations + self.matching_threshold = matching_threshold + + def prepare_config_and_inputs(self): + # SuperGlue expects a grayscale image as input + pixel_values = floats_tensor([self.batch_size, 2, 3, self.image_height, self.image_width]) + config = self.get_config() + return config, pixel_values + + def get_config(self): + return SuperGlueConfig( + keypoint_detector_config=self.keypoint_detector_config, + hidden_size=self.hidden_size, + keypoint_encoder_sizes=self.keypoint_encoder_sizes, + gnn_layers_types=self.gnn_layers_types, + num_attention_heads=self.num_attention_heads, + sinkhorn_iterations=self.sinkhorn_iterations, + matching_threshold=self.matching_threshold, + ) + + def create_and_check_model(self, config, pixel_values): + model = 
SuperGlueForKeypointMatching(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + maximum_num_matches = result.mask.shape[-1] + self.parent.assertEqual( + result.keypoints.shape, + (self.batch_size, 2, maximum_num_matches, 2), + ) + self.parent.assertEqual( + result.matches.shape, + (self.batch_size, 2, maximum_num_matches), + ) + self.parent.assertEqual( + result.matching_scores.shape, + (self.batch_size, 2, maximum_num_matches), + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class SuperGlueModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (SuperGlueForKeypointMatching,) if is_torch_available() else () + all_generative_model_classes = () if is_torch_available() else () + + fx_compatible = False + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + has_attentions = True + + def setUp(self): + self.model_tester = SuperGlueModelTester(self) + self.config_tester = ConfigTester(self, config_class=SuperGlueConfig, has_text_modality=False, hidden_size=64) + + def test_config(self): + self.config_tester.create_and_test_config_to_json_string() + self.config_tester.create_and_test_config_to_json_file() + self.config_tester.create_and_test_config_from_and_save_pretrained() + self.config_tester.create_and_test_config_with_num_labels() + self.config_tester.check_config_can_be_init_without_params() + self.config_tester.check_config_arguments_init() + + @unittest.skip(reason="SuperGlueForKeypointMatching does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="SuperGlueForKeypointMatching does not support input and output embeddings") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="SuperGlueForKeypointMatching does not use feedforward chunking") + def test_feed_forward_chunking(self): + pass + + @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable") + def test_training(self): + pass + + @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable") + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable") + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="SuperGlue does not output any loss term in the forward pass") + def test_retain_grad_hidden_states_attentions(self): + pass + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with 
torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.hidden_states + maximum_num_matches = outputs.mask.shape[-1] + + hidden_states_sizes = ( + self.model_tester.keypoint_encoder_sizes + + [self.model_tester.hidden_size] + + [self.model_tester.hidden_size, self.model_tester.hidden_size * 2] + * len(self.model_tester.gnn_layers_types) + + [self.model_tester.hidden_size] * 2 + ) + + for i, hidden_states_size in enumerate(hidden_states_sizes): + self.assertListEqual( + list(hidden_states[i].shape[-2:]), + [hidden_states_size, maximum_num_matches], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_attention_outputs(self): + def check_attention_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + attentions = outputs.attentions + maximum_num_matches = outputs.mask.shape[-1] + + expected_attention_shape = [ + self.model_tester.num_attention_heads, + maximum_num_matches, + maximum_num_matches, + ] + + for i, attention in enumerate(attentions): + self.assertListEqual( + list(attention.shape[-3:]), + expected_attention_shape, + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + check_attention_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + + check_attention_output(inputs_dict, config, model_class) + + @slow + def test_model_from_pretrained(self): + from_pretrained_ids = ["magic-leap-community/superglue_indoor", "magic-leap-community/superglue_outdoor"] + for model_name in from_pretrained_ids: + model = SuperGlueForKeypointMatching.from_pretrained(model_name) + self.assertIsNotNone(model) + + def test_forward_labels_should_be_none(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + model_inputs = self._prepare_for_class(inputs_dict, model_class) + # Provide an arbitrary sized Tensor as labels to model inputs + model_inputs["labels"] = torch.rand((128, 128)) + + with self.assertRaises(ValueError) as cm: + model(**model_inputs) + self.assertEqual(ValueError, cm.exception.__class__) + + def test_batching_equivalence(self): + """ + Overwriting ModelTesterMixin.test_batching_equivalence since SuperGlue returns `matching_scores` tensors full of + zeros which causes the test to fail, because cosine_similarity of two zero tensors is 0. 
+
+        Discussed here: https://github.com/huggingface/transformers/pull/29886#issuecomment-2481539481
+        """
+
+        def recursive_check(batched_object, single_row_object, model_name, key):
+            if isinstance(batched_object, (list, tuple)):
+                for batched_object_value, single_row_object_value in zip(batched_object, single_row_object):
+                    recursive_check(batched_object_value, single_row_object_value, model_name, key)
+            elif isinstance(batched_object, dict):
+                for batched_object_value, single_row_object_value in zip(
+                    batched_object.values(), single_row_object.values()
+                ):
+                    recursive_check(batched_object_value, single_row_object_value, model_name, key)
+            # do not compare returned loss (0-dim tensor) / codebook ids (int) / caching objects
+            elif batched_object is None or not isinstance(batched_object, torch.Tensor):
+                return
+            elif batched_object.dim() == 0:
+                return
+            else:
+                # indexing the first element does not always work
+                # e.g. models that output similarity scores of size (N, M) would need to index [0, 0]
+                slice_ids = [slice(0, index) for index in single_row_object.shape]
+                batched_row = batched_object[slice_ids]
+                self.assertFalse(
+                    torch.isnan(batched_row).any(), f"Batched output has `nan` in {model_name} for key={key}"
+                )
+                self.assertFalse(
+                    torch.isinf(batched_row).any(), f"Batched output has `inf` in {model_name} for key={key}"
+                )
+                self.assertFalse(
+                    torch.isnan(single_row_object).any(), f"Single row output has `nan` in {model_name} for key={key}"
+                )
+                self.assertFalse(
+                    torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}"
+                )
+                self.assertTrue(
+                    (equivalence(batched_row, single_row_object)) <= 1e-03,
+                    msg=(
+                        f"Batched and Single row outputs are not equal in {model_name} for key={key}. "
+                        f"Difference={equivalence(batched_row, single_row_object)}."
+                    ),
+                )
+
+        def equivalence(tensor1, tensor2):
+            return torch.max(torch.abs(tensor1 - tensor2))
+
+        config, batched_input = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            config.output_hidden_states = True
+
+            model_name = model_class.__name__
+            batched_input_prepared = self._prepare_for_class(batched_input, model_class)
+            model = model_class(config).to(torch_device).eval()
+
+            batch_size = self.model_tester.batch_size
+            single_row_input = {}
+            for key, value in batched_input_prepared.items():
+                if isinstance(value, torch.Tensor) and value.shape[0] % batch_size == 0:
+                    # e.g. musicgen has inputs of size (bs*codebooks); in most cases value.shape[0] == batch_size
+                    single_batch_shape = value.shape[0] // batch_size
+                    single_row_input[key] = value[:single_batch_shape]
+                else:
+                    single_row_input[key] = value
+
+            with torch.no_grad():
+                model_batched_output = model(**batched_input_prepared)
+                model_row_output = model(**single_row_input)
+
+            if isinstance(model_batched_output, torch.Tensor):
+                model_batched_output = {"model_output": model_batched_output}
+                model_row_output = {"model_output": model_row_output}
+
+            for key in model_batched_output:
+                recursive_check(model_batched_output[key], model_row_output[key], model_name, key)
+
+
+def prepare_imgs():
+    dataset = load_dataset("hf-internal-testing/image-matching-test-dataset", split="train")
+    image1 = dataset[0]["image"]
+    image2 = dataset[1]["image"]
+    image3 = dataset[2]["image"]
+    return [[image1, image2], [image3, image2]]
+
+
+@require_torch
+@require_vision
+class SuperGlueModelIntegrationTest(unittest.TestCase):
+    @cached_property
+    def default_image_processor(self):
+        return (
+            AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
+            if is_vision_available()
+            else None
+        )
+
+    @slow
+    def test_inference(self):
+        model = SuperGlueForKeypointMatching.from_pretrained("magic-leap-community/superglue_outdoor").to(torch_device)
+        preprocessor = self.default_image_processor
+        images = prepare_imgs()
+        inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
+        with torch.no_grad():
+            outputs = model(**inputs, output_hidden_states=True, output_attentions=True)
+
+        predicted_number_of_matches = torch.sum(outputs.matches[0][0] != -1).item()
+        predicted_matches_values = outputs.matches[0, 0, :30]
+        predicted_matching_scores_values = outputs.matching_scores[0, 0, :20]
+
+        expected_number_of_matches = 282
+        expected_matches_values = torch.tensor([125,630,137,138,136,143,135,-1,-1,153,
+                                                154,156,117,160,-1,149,147,152,168,-1,
+                                                165,182,-1,190,187,188,189,112,-1,193],
+                                               device=predicted_matches_values.device)  # fmt:skip
+        expected_matching_scores_values = torch.tensor([0.9899,0.0033,0.9897,0.9889,0.9879,0.7464,0.7109,0.0,0.0,0.9841,
+                                                        0.9889,0.9639,0.0114,0.9559,0.0,0.9735,0.8018,0.5190,0.9157,0.0],
+                                                       device=predicted_matches_values.device)  # fmt:skip
+
+        """
+        Because of inconsistencies introduced between CUDA versions, the checks here are less strict. SuperGlue relies
+        on SuperPoint, which may, depending on the CUDA version, return a different number of keypoints (866 or 867 in
+        this specific test example). A different number of keypoints leads to a different number of matches: in the
+        first 20 matches being checked, one keypoint fewer results in one match fewer, and the matching scores also
+        change because the keypoints themselves differ.
+        Therefore, the test checks that the predicted number of matches, the matches and the matching scores are each
+        close to the expected values, allowing fewer than 4 values to differ.
+
+        This was discussed [here](https://github.com/huggingface/transformers/pull/29886#issuecomment-2482752787)
+        Such CUDA inconsistencies can be found
+        [here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300)
+        """
+
+        self.assertTrue(abs(predicted_number_of_matches - expected_number_of_matches) < 4)
+        self.assertTrue(
+            torch.sum(~torch.isclose(predicted_matching_scores_values, expected_matching_scores_values, atol=1e-2)) < 4
+        )
+        self.assertTrue(torch.sum(predicted_matches_values != expected_matches_values) < 4)
diff --git a/tests/models/superpoint/test_image_processing_superpoint.py b/tests/models/superpoint/test_image_processing_superpoint.py
index c2eae872004c..339ca5717e27 100644
--- a/tests/models/superpoint/test_image_processing_superpoint.py
+++ b/tests/models/superpoint/test_image_processing_superpoint.py
@@ -44,6 +44,7 @@ def __init__(
         max_resolution=400,
         do_resize=True,
         size=None,
+        do_grayscale=True,
     ):
         size = size if size is not None else {"height": 480, "width": 640}
         self.parent = parent
@@ -54,11 +55,13 @@ def __init__(
         self.max_resolution = max_resolution
         self.do_resize = do_resize
         self.size = size
+        self.do_grayscale = do_grayscale

     def prepare_image_processor_dict(self):
         return {
             "do_resize": self.do_resize,
             "size": self.size,
+            "do_grayscale": self.do_grayscale,
         }

     def expected_output_image_shape(self, images):
@@ -112,6 +115,7 @@ def test_image_processing(self):
         self.assertTrue(hasattr(image_processing, "size"))
         self.assertTrue(hasattr(image_processing, "do_rescale"))
         self.assertTrue(hasattr(image_processing, "rescale_factor"))
+        self.assertTrue(hasattr(image_processing, "do_grayscale"))

     def test_image_processor_from_dict_with_kwargs(self):
         image_processor = self.image_processing_class.from_dict(self.image_processor_dict)