Skip to content

Commit

Permalink
Add SuperGlue model (#29886)
Browse files Browse the repository at this point in the history
* Initial commit with template code generated by transformers-cli

* Multiple additions to SuperGlue implementation :

- Added the SuperGlueConfig
- Added the SuperGlueModel and its implementation
- Added basic weight conversion script
- Added new ImageMatchingOutput dataclass

* Few changes for SuperGlue

* Multiple changes :
- Added keypoint detection config to SuperGlueConfig
- Completed convert_superglue_to_pytorch and succesfully run inference

* Reverted unintentional change

* Multiple changes :
 - Added SuperGlue to a bunch of places
 - Divided SuperGlue into SuperGlueForImageMatching and SuperGlueModel
 - Added testing images

* Moved things in init files

* Added docs (to be finished depending on the final implementation)

* Added necessary imports and some doc

* Removed unnecessary import

* Fixed make fix-copies bug and ran it

* Deleted SuperGlueModel
Fixed convert script

* Added SuperGlueImageProcessor

* Changed SuperGlue to support batching pairs of images and modified ImageMatchingOutput in consequences

* Changed convert_superglue_to_hf.py script to experiment different ways of reading an image and seeing its impact on performances

* Added initial tests for SuperGlueImageProcessor

* Added AutoModelForImageMatching in missing places and tests

* Fixed keypoint_detector_output instructions

* Fix style

* Adapted to latest main changes

* Added integration test

* Fixed bugs to pass tests

* Added keypoints returned by keypoint detector in the output of SuperGlue

* Added doc to SuperGlue

* SuperGlue returning all attention and hidden states for a fixed number of keypoints

* Make style

* Changed SuperGlueImageProcessor tests

* Revert "SuperGlue returning all attention and hidden states for a fixed number of keypoints"
Changed tests accordingly

This reverts commit 5b3b669c

* Added back hidden_states and attentions masked outputs with tests

* Renamed ImageMatching occurences into KeypointMatching

* Changed SuperGlueImageProcessor to raise error when batch_size is not even

* Added docs and clarity to hidden state and attention grouping function

* Fixed some code and done refactoring

* Fixed typo in SuperPoint output doc

* Fixed some of the formatting and variable naming problems

* Removed useless function call

* Removed AutoModelForKeypointMatching

* Fixed SuperGlueImageProcessor to only accept paris of images

* Added more fixes to SuperGlueImageProcessor

* Simplified the batching of attention and hidden states

* Simplified stack functions

* Moved attention instructions into class

* Removed unused do_batch_norm argument

* Moved weight initialization to the proper place

* Replaced deepcopy for instantiation

* Fixed small bug

* Changed from stevenbucaille to magic-leap repo

* Renamed London Bridge images to Tower Bridge

* Fixed formatting

* Renamed remaining "london" to "tower"

* Apply suggestions from code review

Small changes in the docs

Co-authored-by: amyeroberts <[email protected]>

* Added AutoModelForKeypointMatching

* Changed images used in example

* Several changes to image_processing_superglue and style

* Fixed resample type hint

* Changed SuperGlueImageProcessor and added test case for list of 2 images

* Changed list_of_tuples implementation

* Fix in dummy objects

* Added normalize_keypoint, log_sinkhorn_iterations and log_optimal_transport docstring

* Added missing docstring

* Apply suggestions from code review

Co-authored-by: amyeroberts <[email protected]>

* Apply suggestions from code review

Co-authored-by: amyeroberts <[email protected]>

* Moved forward block at bottom

* Added docstring to forward method

* Added docstring to match_image_pair method

* Changed test_model_common_attributes to test_model_get_set_embeddings test method signature

* Removed AutoModelForKeypointMatching

* Removed image fixtures and added load_dataset

* Added padding of images in SuperGlueImageProcessor

* Cleaned up convert_superglue_to_hf script

* Added missing docs and fixed unused argument

* Fixed SuperGlueImageProcessor tests

* Transposed all hidden states from SuperGlue to reflect the standard (..., seq_len, feature_dim) shape

* Added SuperGlueForKeypointMatching back to modeling_auto

* Fixed image processor padding test

* Changed SuperGlue docs

* changes:
 - Abstraction to batch, concat and stack of inconsistent tensors
 - Changed conv1d's to linears to match standard attention implementations
 - Renamed all tensors to be tensor0 and not tensor_0 and be consistent
 - Changed match image pair to run keypoint detection on all image first, create batching tensors and then filling these tensors matches after matches
 - Various changes in docs, etc

* Changes to SuperGlueImageProcessor:
- Reworked the input image pairs checking function and added tests accordingly
- Added Copied from statements
- Added do_grayscale tag (also for SuperPointImageProcessor)
- Misc changes for better code

* Formatting changes

* Reverted conv1d to linear conversion because of numerical differences

* fix: changed some code to be more straightforward (e.g. filtering keypoints) and converted plot from opencv to matplotlib

* fix: removed unnecessary test

* chore: removed commented code and added back hidden states transpositions

* chore: changed from "inconsistent" to "ragged" function names as suggested

Co-authored-by: amyeroberts <[email protected]>

* docs: applied suggestions

Co-authored-by: amyeroberts <[email protected]>

* docs: updated to display matched output

* chore: applied suggestion for check_image_pairs_input function

Co-authored-by: amyeroberts <[email protected]>

* chore: changed check_image_pairs_input function name to validate_and_format_image_pairs and used validate_preprocess_arguments function

* tests: simplified tests for image input format and shapes

* feat: converted SuperGlue's use of Conv1d with kernel_size of 1 with Linear layers. Changed tests and conversion script accordingly

* feat: several changes to address comments

Conversion script:
- Reverted fuse batchnorm to linear conversion
- Changed all 'nn.Module' to respective SuperGlue models
- Changed conversion script to use regex mapping and match other recent scripts

Modeling SuperGlue:
- Added batching with mask and padding to attention
- Removed unnecessary concat, stack and batch ragged pairs functions
- Reverted batchnorm layer
- Renamed query, key, value and merge layers into q, k, v, out proj
- Removed Union of different Module into nn.Module in _init_weights method typehint
- Changed several method's signature to combine image0 and image1 inputs with appropriate doc changes
- Updated SuperGlue's doc with torch.no_grad()

Updated test to reflect changes in SuperGlue model

* refactor: changed validate_and_format_image_pairs function with clarity

* refactor: changed from one SuperGlueMLP class to a list of SuperGlueMLP class

* fix: fixed forgotten init weight change from last commit

* fix: fixed rebase mistake

* fix: removed leftover commented code

* fix: added typehint and changed some of arguments default values

* fix: fixed attribute default values for SuperGlueConfig

* feat: added SuperGlueImageProcessor post process keypoint matching method with tests

* fix: fixed SuperGlue attention and hidden state tuples aggregation

* chore: fixed mask optionality and reordered tensor reshapes to be cleaner

* chore: fixed docs and error message returned in validate_and_format_image_pairs function

* fix: fixed returned keypoints to be the ones that SuperPoint returns

* fix: fixed check on number of image sizes for post process compared to the pairs in outputs of SuperGlue

* fix: fixed check on number of image sizes for post process compared to the pairs in outputs of SuperGlue (bis)

* fix: Changed SuperGlueMultiLayerPerceptron instantiation to avoid if statement

* fix: Changed convert_superglue_to_hf script to reflect latest SuperGlue changes and got rid of nn.Modules

* WIP: implement Attention from an existing class (like BERT)

* docs: Changed docs to include more appealing matching plot

* WIP: Implement Attention

* chore: minor typehint change

* chore: changed convert superglue script by removing all classes and apply conv to linear conversion in state dict + rearrange keys to comply with changes in model's layers organisation

* Revert "Fixed typo in SuperPoint output doc"

This reverts commit 2120390.

* chore: added comments in SuperGlueImageProcessor

* chore: changed SuperGlue organization HF repo to magic-leap-community

* [run-slow] refactor: small change in layer instantiation

* [run-slow] chore: replaced remaining stevenbucaille org to magic-leap-community

* [run-slow] chore: make style

* chore: update image matching fixture dataset HF repository

* [run-slow] superglue

* tests: overwriting test_batching_equivalence

* [run-slow] superglue

* tests: changed test to cope with value changing depending on cuda version

* [run-slow] superglue

* tests: changed matching_threshold value

* [run-slow] superglue

* [run-slow] superglue

* tests: changed tests for integration

* [run-slow] superglue

* fix: Changed tensor view and permutations to match original implementation results

* fix: updated convert script and integration test to include last change in model

* fix: increase tolerance for CUDA variances

* Apply suggestions from code review

Co-authored-by: Pavel Iakubovskii <[email protected]>

* [run-slow] superglue

* chore: removed blank whitespaces

* [run-slow] superglue

* Revert SuperPoint image processor accident changes

* [run-slow] superglue

* refactor: reverted copy from BERT class

* tests: lower the tolerance in integration tests for SuperGlue

* [run-slow] superglue

* chore: set do_grayscale to False in SuperPoint and SuperGlue image processors

* [run-slow] superglue

* fix: fixed imports in SuperGlue files

* chore: changed do_grayscale SuperGlueImageProcessing default value to True

* docs: added typehint to post_process_keypoint_matching method in SuperGlueImageProcessor

* fix: set matching_threshold default value to 0.0 instead of 0.2

* feat: added matching_threshold to post_process_keypoint_matching method

* docs: update superglue.md to include matching_threshold parameter

* docs: updated SuperGlueConfig docstring for matching_threshold default value

* refactor: removed unnecessary parameters in SuperGlueConfig

* fix: changed from matching_threshold to threshold

* fix: re-revert changes to make SuperGlue attention classes copies of BERT

* [run-slow] superglue

* fix: added missing device argument in post_processing method

* [run-slow] superglue

* fix: add matches different from -1 to compute valid matches in post_process_keypoint_matching (and docstring)

* fix: add device to image_sizes tensor instantiation

* tests: added checks on do_grayscale test

* chore: reordered and added Optional typehint to KeypointMatchingOutput

* LightGluePR suggestions:
- use `post_process_keypoint_matching` as default docs example
- add `post_process_keypoint_matching` in autodoc
- add `SuperPointConfig` import under TYPE_CHECKING condition
- format SuperGlueConfig docstring
- add device in convert_superglue_to_hf
- Fix typo
- Fix KeypointMatchingOutput docstring
- Removed unnecessary line
- Added missing SuperGlueConfig in __init__ methods

* LightGluePR suggestions:
- use batching to get keypoint detection

* refactor: processing images done in 1 for loop instead of 4

* fix: use @ instead of torch.einsum for scores computation

* style: added #fmt skip to long tensor values

* refactor: rollbacked validate_and_format_image_pairs valid and invalid case to more simple ones

* refactor: prepare_imgs

* refactor: simplified `validate_and_format_image_pairs`

* docs: fixed doc

---------

Co-authored-by: steven <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
Co-authored-by: Steven Bucaille <[email protected]>
Co-authored-by: Pavel Iakubovskii <[email protected]>
  • Loading branch information
5 people authored Jan 20, 2025
1 parent 872dfbd commit abe57b6
Show file tree
Hide file tree
Showing 21 changed files with 2,777 additions and 6 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,8 @@
title: SegFormer
- local: model_doc/seggpt
title: SegGpt
- local: model_doc/superglue
title: SuperGlue
- local: model_doc/superpoint
title: SuperPoint
- local: model_doc/swiftformer
Expand Down
1 change: 1 addition & 0 deletions docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ Flax), PyTorch, and/or TensorFlow.
| [SqueezeBERT](model_doc/squeezebert) ||||
| [StableLm](model_doc/stablelm) ||||
| [Starcoder2](model_doc/starcoder2) ||||
| [SuperGlue](model_doc/superglue) ||||
| [SuperPoint](model_doc/superpoint) ||||
| [SwiftFormer](model_doc/swiftformer) ||||
| [Swin Transformer](model_doc/swin) ||||
Expand Down
138 changes: 138 additions & 0 deletions docs/source/en/model_doc/superglue.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the MIT License; you may not use this file except in compliance with
the License.
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# SuperGlue

## Overview

The SuperGlue model was proposed in [SuperGlue: Learning Feature Matching with Graph Neural Networks](https://arxiv.org/abs/1911.11763) by Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz and Andrew Rabinovich.

This model consists of matching two sets of interest points detected in an image. Paired with the
[SuperPoint model](https://huggingface.co/magic-leap-community/superpoint), it can be used to match two images and
estimate the pose between them. This model is useful for tasks such as image matching, homography estimation, etc.

The abstract from the paper is the following:

*This paper introduces SuperGlue, a neural network that matches two sets of local features by jointly finding correspondences
and rejecting non-matchable points. Assignments are estimated by solving a differentiable optimal transport problem, whose costs
are predicted by a graph neural network. We introduce a flexible context aggregation mechanism based on attention, enabling
SuperGlue to reason about the underlying 3D scene and feature assignments jointly. Compared to traditional, hand-designed heuristics,
our technique learns priors over geometric transformations and regularities of the 3D world through end-to-end training from image
pairs. SuperGlue outperforms other learned approaches and achieves state-of-the-art results on the task of pose estimation in
challenging real-world indoor and outdoor environments. The proposed method performs matching in real-time on a modern GPU and
can be readily integrated into modern SfM or SLAM systems. The code and trained weights are publicly available at this [URL](https://github.com/magicleap/SuperGluePretrainedNetwork).*

## How to use

Here is a quick example of using the model. Since this model is an image matching model, it requires pairs of images to be matched.
The raw outputs contain the list of keypoints detected by the keypoint detector as well as the list of matches with their corresponding
matching scores.
```python
from transformers import AutoImageProcessor, AutoModel
import torch
from PIL import Image
import requests

url_image1 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"
image1 = Image.open(requests.get(url_image1, stream=True).raw)
url_image2 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg"
image_2 = Image.open(requests.get(url_image2, stream=True).raw)

images = [image1, image2]

processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")

inputs = processor(images, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
```

You can use the `post_process_keypoint_matching` method from the `SuperGlueImageProcessor` to get the keypoints and matches in a more readable format:

```python
image_sizes = [[(image.height, image.width) for image in images]]
outputs = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
for i, output in enumerate(outputs):
print("For the image pair", i)
for keypoint0, keypoint1, matching_score in zip(
output["keypoints0"], output["keypoints1"], output["matching_scores"]
):
print(
f"Keypoint at coordinate {keypoint0.numpy()} in the first image matches with keypoint at coordinate {keypoint1.numpy()} in the second image with a score of {matching_score}."
)

```

From the outputs, you can visualize the matches between the two images using the following code:
```python
import matplotlib.pyplot as plt
import numpy as np

# Create side by side image
merged_image = np.zeros((max(image1.height, image2.height), image1.width + image2.width, 3))
merged_image[: image1.height, : image1.width] = np.array(image1) / 255.0
merged_image[: image2.height, image1.width :] = np.array(image2) / 255.0
plt.imshow(merged_image)
plt.axis("off")

# Retrieve the keypoints and matches
output = outputs[0]
keypoints0 = output["keypoints0"]
keypoints1 = output["keypoints1"]
matching_scores = output["matching_scores"]
keypoints0_x, keypoints0_y = keypoints0[:, 0].numpy(), keypoints0[:, 1].numpy()
keypoints1_x, keypoints1_y = keypoints1[:, 0].numpy(), keypoints1[:, 1].numpy()

# Plot the matches
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, matching_scores
):
plt.plot(
[keypoint0_x, keypoint1_x + image1.width],
[keypoint0_y, keypoint1_y],
color=plt.get_cmap("RdYlGn")(matching_score.item()),
alpha=0.9,
linewidth=0.5,
)
plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2)
plt.scatter(keypoint1_x + image1.width, keypoint1_y, c="black", s=2)

# Save the plot
plt.savefig("matched_image.png", dpi=300, bbox_inches='tight')
plt.close()
```

![image/png](https://cdn-uploads.huggingface.co/production/uploads/632885ba1558dac67c440aa8/01ZYaLB1NL5XdA8u7yCo4.png)

This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
The original code can be found [here](https://github.com/magicleap/SuperGluePretrainedNetwork).

## SuperGlueConfig

[[autodoc]] SuperGlueConfig

## SuperGlueImageProcessor

[[autodoc]] SuperGlueImageProcessor

- preprocess

## SuperGlueForKeypointMatching

[[autodoc]] SuperGlueForKeypointMatching

- forward
- post_process_keypoint_matching
14 changes: 14 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,7 @@
],
"models.stablelm": ["StableLmConfig"],
"models.starcoder2": ["Starcoder2Config"],
"models.superglue": ["SuperGlueConfig"],
"models.superpoint": ["SuperPointConfig"],
"models.swiftformer": ["SwiftFormerConfig"],
"models.swin": ["SwinConfig"],
Expand Down Expand Up @@ -1268,6 +1269,7 @@
_import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"])
_import_structure["models.seggpt"].extend(["SegGptImageProcessor"])
_import_structure["models.siglip"].append("SiglipImageProcessor")
_import_structure["models.superglue"].extend(["SuperGlueImageProcessor"])
_import_structure["models.superpoint"].extend(["SuperPointImageProcessor"])
_import_structure["models.swin2sr"].append("Swin2SRImageProcessor")
_import_structure["models.textnet"].extend(["TextNetImageProcessor"])
Expand Down Expand Up @@ -3545,6 +3547,12 @@
"Starcoder2PreTrainedModel",
]
)
_import_structure["models.superglue"].extend(
[
"SuperGlueForKeypointMatching",
"SuperGluePreTrainedModel",
]
)
_import_structure["models.superpoint"].extend(
[
"SuperPointForKeypointDetection",
Expand Down Expand Up @@ -5861,6 +5869,7 @@
)
from .models.stablelm import StableLmConfig
from .models.starcoder2 import Starcoder2Config
from .models.superglue import SuperGlueConfig
from .models.superpoint import SuperPointConfig
from .models.swiftformer import (
SwiftFormerConfig,
Expand Down Expand Up @@ -6361,6 +6370,7 @@
from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor
from .models.seggpt import SegGptImageProcessor
from .models.siglip import SiglipImageProcessor
from .models.superglue import SuperGlueImageProcessor
from .models.superpoint import SuperPointImageProcessor
from .models.swin2sr import Swin2SRImageProcessor
from .models.textnet import TextNetImageProcessor
Expand Down Expand Up @@ -8186,6 +8196,10 @@
Starcoder2Model,
Starcoder2PreTrainedModel,
)
from .models.superglue import (
SuperGlueForKeypointMatching,
SuperGluePreTrainedModel,
)
from .models.superpoint import (
SuperPointForKeypointDetection,
SuperPointPreTrainedModel,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@
squeezebert,
stablelm,
starcoder2,
superglue,
superpoint,
swiftformer,
swin,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@
("squeezebert", "SqueezeBertConfig"),
("stablelm", "StableLmConfig"),
("starcoder2", "Starcoder2Config"),
("superglue", "SuperGlueConfig"),
("superpoint", "SuperPointConfig"),
("swiftformer", "SwiftFormerConfig"),
("swin", "SwinConfig"),
Expand Down Expand Up @@ -608,6 +609,7 @@
("squeezebert", "SqueezeBERT"),
("stablelm", "StableLm"),
("starcoder2", "Starcoder2"),
("superglue", "SuperGlue"),
("superpoint", "SuperPoint"),
("swiftformer", "SwiftFormer"),
("swin", "Swin Transformer"),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/image_processing_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@
("segformer", ("SegformerImageProcessor",)),
("seggpt", ("SegGptImageProcessor",)),
("siglip", ("SiglipImageProcessor",)),
("superglue", "SuperGlueImageProcessor"),
("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),
("swin2sr", ("Swin2SRImageProcessor",)),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@
("squeezebert", "SqueezeBertModel"),
("stablelm", "StableLmModel"),
("starcoder2", "Starcoder2Model"),
("superglue", "SuperGlueForKeypointMatching"),
("swiftformer", "SwiftFormerModel"),
("swin", "SwinModel"),
("swin2sr", "Swin2SRModel"),
Expand Down
28 changes: 28 additions & 0 deletions src/transformers/models/superglue/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_superglue import *
from .image_processing_superglue import *
from .modeling_superglue import *
else:
import sys

_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
Loading

0 comments on commit abe57b6

Please sign in to comment.