
Add SuperGlue model #29886

Merged
merged 169 commits into from
Jan 20, 2025

sbucaille
Contributor

sbucaille commented Mar 26, 2024

What does this PR do?

Fixes #25489
This PR is the next step after implementing SuperPoint to implement image matching through keypoint matching.

Colab notebook with inference example:
https://colab.research.google.com/drive/1NhwofZFzy7IMN4irN-jC-9LZy7dx_GZ2?usp=sharing

Who can review?

@amyeroberts

@sbucaille
Contributor Author

@amyeroberts Alright, the PR is open. The branch is still a bit of a mess, but the code in the superglue folder is more or less what it will look like in the end, I think.
I do have several questions / problems:

  1. I tried to follow the LLaVA example, which takes an arbitrary vision config for the vision task, and implemented SuperGlue the same way. But when I instantiate my keypoint_detector using AutoModel.from_config, the SuperPoint weights are not loaded, so I need to force the instantiation using AutoModelForKeypointDetection with keypoint_detection_config.name_or_path. Is this the intended use? (line 390 of modeling_superglue.py)
  2. Which SuperGlue models should be implemented? A SuperGlueForImageMatching containing an AutoModelForKeypointDetection, with regular pixel_values as forward inputs? A SuperGlueForKeypointMatching without an included keypoint detector, whose forward inputs are the keypoints, scores and descriptors of 2 images?
  3. In the case of SuperGlueForImageMatching, how should we treat pixel_values? For now, I raise an error if the batch size is not a multiple of 2 (the rest of the code does not support batching yet).
  4. About the SuperPoint PR: I couldn't find it in the release notes. Is that a mistake? 😬

@sbucaille
Contributor Author

sbucaille commented Apr 9, 2024

@ydshieh Hey, I got a weird bug: when I run make fix-copies to make sure SuperGlue is imported in all the necessary files, it crashes here:

❯ make fix-copies
python utils/check_copies.py --fix_and_overwrite
Traceback (most recent call last):
  File "/home/steven/transformers_fork/transformers/utils/check_copies.py", line 315, in split_code_into_blocks
    ).groups()[0]
AttributeError: 'NoneType' object has no attribute 'groups'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/steven/transformers_fork/transformers/utils/check_copies.py", line 1218, in <module>
    check_copies(args.fix_and_overwrite, args.file)
  File "/home/steven/transformers_fork/transformers/utils/check_copies.py", line 852, in check_copies
    new_diffs = is_copy_consistent(filename, overwrite, buffer)
  File "/home/steven/transformers_fork/transformers/utils/check_copies.py", line 675, in is_copy_consistent
    target_lines, theoretical_code, theoretical_code_splits = find_code_and_splits(
  File "/home/steven/transformers_fork/transformers/utils/check_copies.py", line 529, in find_code_and_splits
    code_splits = split_code_into_blocks(
  File "/home/steven/transformers_fork/transformers/utils/check_copies.py", line 319, in split_code_into_blocks
    raise ValueError(
ValueError: Tried to split a class or function. It did not work. Error comes from line -1: 
```
        help="Name of the model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )
    parser.add_argument(
        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
    )

    args = parser.parse_args()
    convert_dinov2_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
```

make: *** [Makefile:84: fix-copies] Error 1

This seems to come from dinov2/convert_dinov2_to_hf.py, which makes no sense. Any idea?

@amyeroberts
Collaborator

@sbucaille Apologies for the delay in my response. I've answered your initial questions below - let me know if anything isn't clear

I tried to follow the LlaVA example, which takes an arbitrary vision config for the vision task and implemented the same way in SuperGlue, but when I instantiate my keypoint_detector using AutoModel.from_config, SuperPoint weights are not loaded. So I need to force the instantiation using AutoModelForKeypointDetection with the keypoint_detection_config.name_or_path, is this the intended use ? (line 390 of modeling_superglue.py)

This shouldn't be necessary. Doing:

self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config)

is the correct way to go. If the weights are all newly initialized, the first thing I'd suspect is that the checkpoint being used doesn't contain the SuperPoint weights.

What SuperGlue models should be implemented ? A SuperGlueForImageMatching containing an AutoModelForKeypointDetection with inputs in forward as regular pixel_values ? A SuperGlueForKeypointMatching without keypoint detector included but with inputs of the forward as the keypoints, scores and descriptors of 2 images ?

SuperGlueForImageMatching taking pixel_values. I'll think about the name to see if there's something better we can use. "ImageMatching" is a bit ambiguous.

In the example of SuperGlueForImageMatching , how should we treat pixel_values ? For now I raise an error of the batch size if not modulo 2 (the remaining code does not cover batching yet)

In this case, are the images interleaved in pairs? e.g. the sequence goes [image_1_a, image_1_b, image_2_a, image_2_b, ...]

I think I would do something more similar to sentence similarity for text models:

  • The image processor takes a pair of images or a list of pairs of images
  • The images are processed and returned as a tensor of shape (B, 2, C, H, W)
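
A minimal NumPy sketch of the pair layout suggested above, assuming every image has already been converted to a (C, H, W) array of the same size. `stack_image_pairs` is a hypothetical helper for illustration, not the image processor's actual code.

```python
import numpy as np


def stack_image_pairs(pairs):
    """Stack a list of (image_a, image_b) pairs into a (B, 2, C, H, W) array.

    Hypothetical helper: assumes every image is a (C, H, W) array with the same shape.
    """
    if len(pairs) == 0:
        raise ValueError("Expected at least one pair of images.")
    for pair in pairs:
        if len(pair) != 2:
            raise ValueError(f"Each element must be a pair of images, got {len(pair)} images.")
    # Stack the two images of each pair along a new axis -> (2, C, H, W),
    # then stack the pairs along the batch axis -> (B, 2, C, H, W).
    return np.stack([np.stack(pair, axis=0) for pair in pairs], axis=0)


# Usage: two pairs of single-channel 4x5 images
pairs = [
    (np.zeros((1, 4, 5)), np.ones((1, 4, 5))),
    (np.ones((1, 4, 5)), np.zeros((1, 4, 5))),
]
batch = stack_image_pairs(pairs)
print(batch.shape)  # (2, 2, 1, 4, 5)
```

Keeping the pair as its own axis (rather than interleaving along the batch axis) means the model can index `pixel_values[:, 0]` and `pixel_values[:, 1]` without relying on an ordering convention.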

About the SuperPoint PR, I couldn't see it in the release note, is it a mistake ? 😬

This was just an oversight, sorry: the model was added at the last minute as a patch, after the release branch had already been cut. I'll update the notes.

@sbucaille
Contributor Author

@ydshieh Never mind, I found the problem: it came from a leftover # Copied from transformers.models.dinov2.convert_dinov2_to_hf comment above a function. Sorry for the inconvenience.

@amyeroberts Everything is clear. I fixed the loading of the embedded keypoint_detector state dict in the convert_superglue_to_hf.py script. So we agree that saving SuperGlueForImageMatching weights in a new Hub repo necessarily means we also save the keypoint_detector weights with them, right? Is this how it is implemented in LLaVA?

I'll implement the new SuperGlueImageProcessor later. While trying out inference, I noticed a difference in the results depending on how the image is read.
In SuperPoint it was not that noticeable, but with the two images I've added in this PR, for example, using the SuperPointImageProcessor (which uses PIL.Image.open()) results in SuperGlue matching 144 points, whereas using the original authors' code for reading (see below) results in SuperGlue matching 175 points (I get this number whether I run their original code or add it to the convert_superglue_to_hf.py script). Since their implementation uses OpenCV instead of PIL, should we change the SuperPointImageProcessor? What is the convention in transformers for image reading libraries? This difference has a negative snowball effect on performance, although I haven't tried it on many examples.

The original read_image function from the authors:

import cv2
import numpy as np
import torch

# Note: process_resize is defined alongside read_image in the original SuperGlue repository (not shown here)


def frame2tensor(frame, device):
    # Scale to [0, 1] and add batch and channel dimensions -> (1, 1, H, W)
    return torch.from_numpy(frame/255.).float()[None, None].to(device)


def read_image(path, device, resize, rotation, resize_float):
    image = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
    if image is None:
        return None, None, None
    w, h = image.shape[1], image.shape[0]
    w_new, h_new = process_resize(w, h, resize)
    scales = (float(w) / float(w_new), float(h) / float(h_new))

    if resize_float:
        image = cv2.resize(image.astype('float32'), (w_new, h_new))
    else:
        image = cv2.resize(image, (w_new, h_new)).astype('float32')

    if rotation != 0:
        image = np.rot90(image, k=rotation)
        if rotation % 2:
            scales = scales[::-1]

    inp = frame2tensor(image, device)
    return image, inp, scales

@amyeroberts
Collaborator

So we agree that saving a SuperGlueForImageMatching weights in a new Hub repo necessarily means we also save the keypoint_detector weights with it right ? Is this how it is implemented in LlaVa ?

Yep, that's how it's done in LLaVA and other composite models. You can check this by inspecting the safetensors weight names of a checkpoint on the Hub: https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf/tree/main (click the icon next to a .safetensors file)
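
To illustrate why the detector weights travel with the composite checkpoint, here is a plain PyTorch sketch in which hypothetical toy modules stand in for SuperGlue and SuperPoint: a submodule's parameters appear in the parent's `state_dict()` prefixed by its attribute name (here `keypoint_detector.`), which is the kind of naming you would expect when inspecting the weight names of a composite checkpoint.

```python
import torch
from torch import nn


class ToyKeypointDetector(nn.Module):
    # Hypothetical stand-in for a keypoint detector such as SuperPoint
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 4, kernel_size=3, padding=1)


class ToyMatcher(nn.Module):
    # Hypothetical composite model: the detector is a submodule,
    # so its parameters are saved in the same state dict as the matcher's own
    def __init__(self):
        super().__init__()
        self.keypoint_detector = ToyKeypointDetector()
        self.proj = nn.Linear(4, 4)


model = ToyMatcher()
state_dict = model.state_dict()
print(sorted(state_dict))
# ['keypoint_detector.conv.bias', 'keypoint_detector.conv.weight', 'proj.bias', 'proj.weight']

# A freshly constructed model starts from its own random initialization;
# loading the composite state dict restores the detector weights too
reloaded = ToyMatcher()
reloaded.load_state_dict(state_dict)
```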

In SuperPoint it was not that noticeable, but for example, with the two images I've added in this PR, using the SuperPointImageProcessor (which uses PIL.Image.open())

None of our image processors should be using PIL.Image.open. They should take already opened images. I can't see where it's used in the image processor?

The image processors accept PIL.Image.Image, torch.tensor, jax.array, tf.tensor or np.array, but they are not responsible for loading / opening images. If, for testing the port, you want to use cv2 and then convert the images to one of these formats, that should be OK.

It would be interesting to know what the difference is between the two images after loading from PIL vs cv2
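
One way to quantify that difference is a small NumPy helper, sketched below under the assumption that both decodings are uint8 arrays of identical shape (for example `np.asarray(PIL.Image.open(path).convert("L"))` versus `cv2.imread(path, cv2.IMREAD_GRAYSCALE)`). `compare_images` is a hypothetical name, not part of either library.

```python
import numpy as np


def compare_images(image_a, image_b):
    """Summarize pixel differences between two decodings of the same image.

    Hypothetical helper: assumes both inputs are uint8 arrays with identical shapes.
    """
    if image_a.shape != image_b.shape:
        raise ValueError(f"Shape mismatch: {image_a.shape} vs {image_b.shape}")
    # Cast to a signed type so that the subtraction cannot wrap around
    diff = np.abs(image_a.astype(np.int16) - image_b.astype(np.int16))
    return {
        "max_abs_diff": int(diff.max()),
        "mean_abs_diff": float(diff.mean()),
        "fraction_differing": float((diff > 0).mean()),
    }


a = np.array([[0, 10], [20, 30]], dtype=np.uint8)
b = np.array([[0, 12], [20, 25]], dtype=np.uint8)
print(compare_images(a, b))
# {'max_abs_diff': 5, 'mean_abs_diff': 1.75, 'fraction_differing': 0.5}
```

Running it once right after decoding and again after each library's resize would show which step introduces the discrepancy.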

@sbucaille
Contributor Author

sbucaille commented Apr 11, 2024

@amyeroberts Indeed, the image processors don't read images; I was mistaken.

The image processors accept PIL.Image.Image, torch.tensor, jax.array, tf.tensor or np.array, but it's not responsible for loading / opening images. If for testing the port you want to use cv and then convert the images to one of these formats, that should be OK.
It would be interesting to know what the difference is between the two images after loading from PIL vs cv2

About that, I've modified the convert_superglue_to_hf.py script to add multiple ways of reading the image. After some experimenting, it looks like reading the image with PIL.Image.open(), cv2.imread(..., cv2.IMREAD_GRAYSCALE) or cv2.imread(..., cv2.IMREAD_COLOR) (with conversion from BGR to RGB) results in the same number of matches. I haven't looked at the images themselves, but the only remaining difference is in how images are resized: in the original code it is done with cv2, whereas in transformers it is done with PIL by default. As I said earlier, I haven't tried it on many images, so maybe this particular pair is affected in terms of performance and other pairs won't be. How strict is the policy on replicating a model's performance when it is implemented in transformers?

Some additional notes:

  • I've added the SuperGlueImageProcessor. For obvious reasons, I made it instantiate a SuperPointImageProcessor, since the grayscale processing is exactly the same. SuperGlueImageProcessor only makes sure there is an even number of images and pairs them up interleaved, as you suggested. I was also thinking about a "pairing mode" where the user could pass an arbitrary number of images and the processor would group them based on the mode. For example, with combinations, [im0, im1, im2] would result in [(im0, im1), (im0, im2), (im1, im2)]?
  • I've implemented batching in SuperGlue similarly to SuperPoint, with a masked output.
  • How should hidden_states be handled in SuperGlue? Should they be the SuperPoint hidden_states, or something referring to the keypoint_encoder or the GNN?
  • Also, as a naming suggestion, could ImageKeypointMatching be better than just ImageMatching? And should we rename KeypointDetection references to ImageKeypointDetection?
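
The "combinations" pairing mode floated above can be sketched with `itertools.combinations`. Both function names below are hypothetical, and the flattening step assumes the processor keeps the interleaved [a0, b0, a1, b1, ...] layout discussed earlier.

```python
from itertools import combinations


def pair_by_combinations(images):
    """Hypothetical "combinations" pairing mode: every unordered pair of images.

    n images produce n * (n - 1) / 2 pairs.
    """
    return list(combinations(images, 2))


def interleave(pairs):
    """Flatten pairs into the interleaved layout [a0, b0, a1, b1, ...]."""
    return [image for pair in pairs for image in pair]


pairs = pair_by_combinations(["im0", "im1", "im2"])
print(pairs)  # [('im0', 'im1'), ('im0', 'im2'), ('im1', 'im2')]
print(interleave(pairs))  # ['im0', 'im1', 'im0', 'im2', 'im1', 'im2']
```

Note that the number of pairs grows quadratically, so such a mode would multiply the keypoint detection and matching cost for large image lists.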

@sbucaille
Contributor Author

sbucaille commented Apr 15, 2024

@amyeroberts Me again. I had time to write some tests, but I ended up with a very verbose implementation of the forward method. To cover all the tests about hidden states, we need to output 2 variables (last_hidden_state, hidden_states), but these can't be stacked easily: SuperGlue states are sparse tensors whose last dimension depends on the number of keypoints, so we need to store them per image. Hence we have 4 variables throughout the method (last_hidden_state_0, last_hidden_state_1, hidden_states_0, hidden_states_1). In the end, we have 2 variables, but they need to be padded tensors, like matches and matching_scores.
This is just the beginning, as SuperGlue also includes attention, so those tensors will need to be handled as well. But given the current state of the forward method, I preferred to keep it like this and get your opinion first.
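
The padding described above can be sketched in NumPy: per-image states of shape (D, N_i) are copied into a zero-filled (B, D, N_max) array alongside a boolean mask, similar in spirit to how matches and matching_scores are padded. `pad_per_image_states` is a hypothetical helper, not the code in modeling_superglue.py.

```python
import numpy as np


def pad_per_image_states(states):
    """Pad a list of (D, N_i) arrays to a (B, D, N_max) array plus a (B, N_max) mask.

    Hypothetical helper: each image has D-dimensional features for its own number
    of keypoints N_i, so only the last dimension differs between images.
    """
    batch_size = len(states)
    dim = states[0].shape[0]
    max_keypoints = max(state.shape[1] for state in states)
    padded = np.zeros((batch_size, dim, max_keypoints), dtype=states[0].dtype)
    mask = np.zeros((batch_size, max_keypoints), dtype=bool)
    for i, state in enumerate(states):
        num_keypoints = state.shape[1]
        padded[i, :, :num_keypoints] = state
        mask[i, :num_keypoints] = True  # True marks real keypoints, False marks padding
    return padded, mask


states = [np.ones((3, 2)), np.full((3, 4), 2.0)]
padded, mask = pad_per_image_states(states)
print(padded.shape)  # (2, 3, 4)
print(mask.astype(int).tolist())  # [[1, 1, 0, 0], [1, 1, 1, 1]]
```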

I have issues with a few tests:

  • All tests related to TorchScript fail because they allocate too much GPU memory when reaching the einsum calls. Are there known issues with einsum in these TorchScript tests?
>           return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
E           torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 82.88 GiB. GPU 0 has a total capacty of 15.99 GiB of which 12.28 GiB is free. Including non-PyTorch memory, this process has 17179869184.00 GiB memory in use. Of the allocated memory 1.09 GiB is allocated by PyTorch, and 85.38 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
  • The last failing test is about initialization, where the weights don't get the expected values, although the following check seems a bit odd to me:
for name, param in model.named_parameters():
                if param.requires_grad:
>                   self.assertIn(
                        ((param.data.mean() * 1e9).round() / 1e9).item(),
                        [0.0, 1.0],
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )
E                   AssertionError: 0.06238596513867378 not found in [0.0, 1.0] : Parameter keypoint_encoder.encoder.layers.0.weight of model <class 'transformers.models.superglue.modeling_superglue.SuperGlueForImageMatching'> seems not properly initialized

The last test to write is the one that uses the pretrained model and checks whether it outputs the expected values.
Some docs are also still missing here and there.

EDIT: problem fixed

sbucaille force-pushed the add_superglue branch 2 times, most recently from 79dcc0c to b68ae29, May 11, 2024 21:05
@sbucaille
Contributor Author

sbucaille commented May 11, 2024

Hey @amyeroberts
Following up on the PR: I adapted the code to the recent changes on the main branch and noticed several new tests failing on my side (the test_save_load_low_cpu... ones) with the following error:

        for p1, p2 in zip(model_low_usage.parameters(), model_non_low_usage.parameters()):
>           self.assertEquals(p1.data.ne(p2.data).sum(), 0)
E           RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

tests/test_modeling_common.py:504: RuntimeError

Is this a known problem on the main branch (I tried the SuperPoint tests and got the same result), or something on my side?

EDIT: problem fixed

@sbucaille
Contributor Author

@amyeroberts Hey, you can disregard my last two posts, as I found the origin of the problems.
I do have a question about returning hidden_states and attentions, though. In SuperPoint, hidden_states are returned up to the point where we get sparse tensors with an arbitrary number of detected keypoints. In SuperGlue, everything computed after the keypoint detector has this arbitrary dimension, which depends on the number of detected keypoints. I worked on this commit for several hours before realizing it 😅. The tests actually pass because we make sure a fixed number of keypoints is returned, so all hidden_states and attentions are properly concatenated, but adding output_hidden_states=True to an inference with a different number of keypoints per image would crash. Adapting this code to an arbitrary number of keypoints would mean a lot of lines dedicated to creating mask tensors and filling them (which already make up a big portion of the code just for the matches and matching scores).
So the question worth asking is: is it relevant to return such states from SuperGlue? I would say it makes sense if the model acted like a backbone, but in my opinion it is far from that.
If yes, I can implement it, no worries. If not, I'll revert the commit, we can consider the model fully implemented, and we can start the review process and make a final pass on it.

@amyeroberts
Collaborator

@sbucaille Apologies for the delay in my response on this PR.

If I've understood correctly, you're asking about whether we should return the hidden states from the SuperPoint model when calling SuperGlue, is that right?

If so, I'd say no, we don't need to return them, and I agree it can act as if it were a backbone. Since, technically, any keypoint detection model can feed into SuperGlue, we can't assume they all have the same kind of hidden states.

@sbucaille
Contributor Author

@amyeroberts Hey! No worries at all!!
Part of the question was indeed whether SuperGlue should return the SuperPoint hidden_states, yes. But the other part concerned SuperGlue's own hidden states: should it return hidden states and/or attentions, since it involves some MLPs and attention layers? I assumed not and reverted my changes, so SuperGlue returns only the matches, matching_scores and keypoints, along with the mask.
Other previously raised points still need to be addressed:

  • Regarding the SuperGlueImageProcessor, I currently copied the behavior from the SuperPointImageProcessor; should it rather be instantiated in the SuperGlueImageProcessor?
  • Still on the SuperGlueImageProcessor, I currently batch images as you suggested: a list of images [image0, image1, ..., image{n-1}, image{n}] becomes [[image0, image1], ..., [image{n-1}, image{n}]] as a [B, 2, C, H, W] tensor. BUT if the number of images is odd, the last image is duplicated so that we still have a pair of images and the code doesn't crash.
  • Regarding your concern about naming the model and output with ImageMatching, I suggest ImageKeypointMatching. And maybe we could rename KeypointDetection references to ImageKeypointDetection?

All tests pass. Should we move on to a first review?

@amyeroberts
Collaborator

@sbucaille Ah, OK, I see. Regarding returning hidden_states and attentions, yes, the model should return these. It's fine if the dimension here varies depending on the number of points detected, as long as not all of the dimensions vary (if that makes sense?). That is, for each "block" of the model, there should be an associated hidden_states and attentions tensor which is returned if output_hidden_states and output_attentions are True.
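
The "one entry per block" convention can be sketched as a toy NumPy forward loop that accumulates tuples only when the corresponding flag is set. The block below is a hypothetical single-head attention stand-in, not SuperGlue's actual layer; it only shows that N (the number of keypoints) may vary per input while the tuple structure stays fixed.

```python
import numpy as np


def run_blocks(hidden, blocks, output_hidden_states=False, output_attentions=False):
    """Toy forward loop collecting one hidden state / attention entry per block.

    Each hypothetical block maps (N, D) features to (new_features, attention),
    where attention has shape (N, N); N may vary between inputs, D does not.
    """
    all_hidden_states = () if output_hidden_states else None
    all_attentions = () if output_attentions else None
    for block in blocks:
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden,)
        hidden, attention = block(hidden)
        if output_attentions:
            all_attentions = all_attentions + (attention,)
    if output_hidden_states:
        # Also record the output of the last block
        all_hidden_states = all_hidden_states + (hidden,)
    return hidden, all_hidden_states, all_attentions


def toy_attention_block(x):
    # Row-wise softmax over scaled dot-product scores, then a residual update
    scores = x @ x.T / np.sqrt(x.shape[1])
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    return x + weights @ x, weights


features = np.random.default_rng(0).normal(size=(5, 8))  # 5 keypoints, 8-dim descriptors
_, hidden_states, attentions = run_blocks(features, [toy_attention_block] * 3, True, True)
print(len(hidden_states), len(attentions))  # 4 3
print(attentions[0].shape)  # (5, 5)
```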

For the other points:

  • Regarding the SuperGlueImageProcessor, I currently copied the behavior from the SuperPointImageProcessor, should that rather be instantiated in the SuperGlueImageProcessor ?

Could you clarify this a bit? Specifically instantiation of what and where?

  • Still on the SuperGlueImageProcessor, I currently batched images as you suggested : taking a list of images [image0, image1, ..., image{n-1}, image{n}] into [[image0, image1], ..., [image{n-1}, image{n}]] as a [B, 2, C, H, W] tensor, BUT if n is odd then the last image is doubled so we still have a pair of images and code is not crashing.

Is there any valid input to SuperGlue which involves one image? If not, then we can raise an exception when the image processor is called if the number of images isn't even
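
A minimal sketch of that check, assuming the processor receives a flat list [im0, im1, im2, im3, ...] that it groups into consecutive pairs. `group_into_pairs` is a hypothetical name, and the error message is illustrative.

```python
def group_into_pairs(images):
    """Group a flat list [im0, im1, im2, im3, ...] into [(im0, im1), (im2, im3), ...].

    Hypothetical helper: an odd number of images is rejected instead of
    silently duplicating the last image.
    """
    if len(images) % 2 != 0:
        raise ValueError(
            f"Images are matched in pairs, so the number of images must be even; got {len(images)}."
        )
    # Even indices are the first image of each pair, odd indices the second
    return list(zip(images[0::2], images[1::2]))


print(group_into_pairs(["a", "b", "c", "d"]))  # [('a', 'b'), ('c', 'd')]
```

Failing early here surfaces the mistake at preprocessing time rather than producing a silently duplicated pair.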

  • For the concern you had about the naming for the model and output with ImageMatching, I suggest ImageKeypointMatching. And maybe we could rename KeypointDetection references to ImageKeypointDetection ?

ImageKeypointMatching sounds good to me! Or just KeypointMatching

We could, although renaming models shouldn't be done lightly, as it would be a breaking change. I don't think the Image prefix adds much.

@sbucaille
Contributor Author

Hey @amyeroberts,

@sbucaille Ah, OK, I see. Regarding returning hidden_states and attentions, yes, the model should return these. It's fine if the dimension here varies depending on the number of points detected, as long as not all of the dimensions vary (if that makes sense?). That is, for each "block" of the model, there should be an associated hidden_states and attentions tensor which is returned if output_hidden_states and output_attentions are True.

Alright, I made the changes to output the hidden states and attentions. In the hidden states, only the last dimension is "masked" and equals the highest number of keypoints across all the images; the same goes for the last and second-to-last dimensions of the attentions.

  • Regarding the SuperGlueImageProcessor, I currently copied the behavior from the SuperPointImageProcessor, should that rather be instantiated in the SuperGlueImageProcessor ?

Could you clarify this a bit? Specifically instantiation of what and where?

Since SuperGlue uses SuperPoint as its keypoint detector, the SuperGlueImageProcessor needs to process images so that SuperPoint can handle them. For now, I copied the part where SuperPointImageProcessor turns the images into grayscale tensors, and added the logic for batching images into pairs after that. My question is: should SuperGlueImageProcessor use the SuperPointImageProcessor as a first "internal" processor, so that images comply with SuperPoint's requirements, and then apply the pair-batching logic? Or should I keep it as it is? (you can find the code here)

  • Still on the SuperGlueImageProcessor, I currently batched images as you suggested : taking a list of images [image0, image1, ..., image{n-1}, image{n}] into [[image0, image1], ..., [image{n-1}, image{n}]] as a [B, 2, C, H, W] tensor, BUT if n is odd then the last image is doubled so we still have a pair of images and code is not crashing.

Is there any valid input to SuperGlue which involves one image? If not, then we can raise an exception when the image processor is called if the number of images isn't even

Made these changes and updated the tests accordingly.

  • For the concern you had about the naming for the model and output with ImageMatching, I suggest ImageKeypointMatching. And maybe we could rename KeypointDetection references to ImageKeypointDetection ?

ImageKeypointMatching sounds good to me! Or just KeypointMatching

We could, although renaming models shouldn't be done lightly, as it would be a breaking change. I don't think the Image prefix adds much.

I've renamed ImageMatching occurrences to KeypointMatching.

I still need to make a final pass on the code to check for quality, refactoring, naming and commenting issues, but overall I'd say we have everything.

amyeroberts (Collaborator) left a comment

Thanks for working on adding this model!

I've just done an initial first pass, so there will be some things I'll revisit in a more in-depth review. I might have missed something, but my first thought when looking at the structure is that a lot of the code takes both images and returns e.g. their respective scores, even though the scores aren't dependent on one another, e.g. in SuperGlueAttentionalGNN. I think it would be better to have these layers take just one image, then call each layer twice and combine the results as needed in the final stages.
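
A toy NumPy illustration of the suggested structure, for layers whose output depends on only one image: the same per-image function is called once per image, and the two images interact only in a final step. The projection is a hypothetical stand-in; the final scaled dot product mirrors the way the original SuperGlue implementation turns final descriptors into a score matrix (dividing by the square root of the descriptor dimension), but everything else here is illustrative.

```python
import numpy as np


def encode_image(descriptors, weight):
    """Per-image step: depends only on one image's (D, N) descriptors (toy linear projection)."""
    return weight @ descriptors


def match_scores(desc0, desc1):
    """Cross-image step: (D, N) and (D, M) -> (N, M) similarities scaled by sqrt(D)."""
    return desc0.T @ desc1 / np.sqrt(desc0.shape[0])


rng = np.random.default_rng(0)
weight = rng.normal(size=(4, 4))        # shared between both images
descriptors0 = rng.normal(size=(4, 6))  # image 0: 6 keypoints
descriptors1 = rng.normal(size=(4, 3))  # image 1: 3 keypoints

# Call the same per-image layer twice, then combine only at the end
projected0 = encode_image(descriptors0, weight)
projected1 = encode_image(descriptors1, weight)
scores = match_scores(projected0, projected1)
print(scores.shape)  # (6, 3)
```

Layers that genuinely mix the two images (such as cross-attention) would still need both inputs; the split applies to the parts that don't.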

src/transformers/models/superglue/modeling_superglue.py: 7 review threads (outdated, resolved)
@sbucaille
Contributor Author

sbucaille commented May 26, 2024

Hey @amyeroberts,
I addressed the review comments.
I am sorry for the offence regarding the naming of the images of the Tower Bridge, but in my defense, I just copied it from the original repo 😬

@amyeroberts
Collaborator

I am sorry for the offence regarding the naming of the images of the Tower Bridge, but in my defense, I just copied it from the original repo 😬

@sbucaille No offence taken! It's a very common mistake 😄

amyeroberts (Collaborator) left a comment

Looking great - thanks for the continued work on this! A few general structural things, but we're pretty close to being ready 🤗

src/transformers/models/auto/modeling_auto.py: review thread (outdated, resolved)
docs/source/en/model_doc/superglue.md: review thread (outdated, resolved)
src/transformers/models/superglue/modeling_superglue.py: 4 review threads (outdated, resolved)
matches_mask[i, 1, : _matches_1.shape[1]] = 1
keypoints[i, 0, : _keypoints_0.shape[1], :] = _keypoints_0
keypoints[i, 1, : _keypoints_1.shape[1], :] = _keypoints_1

Collaborator

Regarding batching of the outputs, we don't want to batch all of the attentions and hidden states together into one big tensor. The pattern for other models is that a tuple of tensors is returned, with each element in the tuple representing a layer or block of the model. Sorry if I wasn't clear about this earlier.

Contributor Author

Oh, but that's not the case: we don't have one big tensor for all hidden_states and one big tensor for all attentions; it is still a tuple of hidden states and a tuple of attentions. What we are batching together here are the multiple tuples of hidden states and attentions coming from the different image pairs being matched.
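
A NumPy sketch of the distinction being made: each image pair yields its own tuple with one entry per layer, and batching keeps the tuple-per-layer structure while padding and stacking across pairs within each layer. `batch_hidden_states_across_pairs` is hypothetical and assumes every layer of a given pair has the same number of keypoints.

```python
import numpy as np


def batch_hidden_states_across_pairs(per_pair_hidden_states):
    """Combine per-pair tuples of hidden states into one tuple, one entry per layer.

    per_pair_hidden_states[p][l] has shape (D, N_p): layer l of image pair p, where the
    number of keypoints N_p differs between pairs. Each output entry is padded and
    stacked to (num_pairs, D, N_max); the result is still a tuple over layers.
    """
    num_layers = len(per_pair_hidden_states[0])
    max_keypoints = max(states[0].shape[-1] for states in per_pair_hidden_states)
    batched = []
    for layer in range(num_layers):
        layer_states = [states[layer] for states in per_pair_hidden_states]
        padded = np.zeros((len(layer_states), layer_states[0].shape[0], max_keypoints))
        for p, state in enumerate(layer_states):
            padded[p, :, : state.shape[-1]] = state
        batched.append(padded)
    return tuple(batched)


rng = np.random.default_rng(0)
pair_a = (rng.normal(size=(4, 3)), rng.normal(size=(4, 3)))  # 2 layers, 3 keypoints
pair_b = (rng.normal(size=(4, 5)), rng.normal(size=(4, 5)))  # 2 layers, 5 keypoints
batched = batch_hidden_states_across_pairs([pair_a, pair_b])
print(len(batched), batched[0].shape)  # 2 (2, 4, 5)
```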

@sbucaille
Contributor Author

Hey @amyeroberts
I addressed most of the issues and answered some of your questions. There are still a few points to address 😄

@sbucaille
Contributor Author

@qubvel done !

@sbucaille
Contributor Author

sbucaille commented Dec 2, 2024

Hi @qubvel , @ArthurZucker ,
I just pushed a small commit fixing a tiny bug I found in the post-processing method. The program would crash because the outputs may be on a different device than the CPU, where image_sizes are instantiated as tensors.
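
The general shape of such a fix is to build any helper tensor on the same device as the model outputs. Below is a minimal PyTorch sketch, assuming keypoints are relative (x, y) coordinates in [0, 1] and image sizes are given as (height, width); `scale_keypoints` is a hypothetical function, and the actual post-processing method may differ.

```python
import torch


def scale_keypoints(keypoints, image_sizes):
    """Convert relative (x, y) keypoints in [0, 1] to absolute pixel coordinates.

    Hypothetical helper. keypoints: (B, N, 2) tensor, possibly on a GPU.
    image_sizes: list of (height, width) tuples, one per batch element.
    """
    # Create the size tensor on the same device (and dtype) as the model outputs;
    # a CPU tensor multiplied with CUDA outputs would raise a device mismatch error
    sizes = torch.as_tensor(image_sizes, dtype=keypoints.dtype, device=keypoints.device)
    # Reorder (height, width) -> (width, height) to line up with (x, y), then broadcast over N
    scale = sizes.flip(-1)[:, None, :]
    return keypoints * scale


keypoints = torch.tensor([[[0.5, 0.25]]])
print(scale_keypoints(keypoints, [(480, 640)]))  # tensor([[[320., 120.]]])
```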

@qubvel
Member

qubvel commented Dec 3, 2024

Thanks, @sbucaille! Arthur is on/off this week, but hopefully, he will be able to review it then. Thanks for your patience 🤗

- use `post_process_keypoint_matching` as default docs example
- add `post_process_keypoint_matching` in autodoc
- add `SuperPointConfig` import under TYPE_CHECKING condition
- format SuperGlueConfig docstring
- add device in convert_superglue_to_hf
- Fix typo
- Fix KeypointMatchingOutput docstring
- Removed unnecessary line
- Added missing SuperGlueConfig in __init__ methods
- use batching to get keypoint detection
@sbucaille
Contributor Author

Hey @ArthurZucker! Gentle 🎅 bump here, happy holidays!

@sbucaille
Contributor Author

Hey @qubvel! Wishing you all the best for the new year! I heard some good news about vision PR merges, so here is a gentle bump! 😬

ArthurZucker (Collaborator) left a comment

A few small nits! Sorry that this took so long 🤗
Let's make sure the model / feature (keypoint matching) is easy to use, and thus add basic functionalities for it!

docs/source/en/model_doc/superglue.md: review thread (resolved)
Comment on lines +81 to +115
import matplotlib.pyplot as plt
import numpy as np

# Assumes `image1`, `image2` (RGB PIL images) and `outputs` (post-processed matching results)
# from the earlier part of the docs example

# Create side by side image
merged_image = np.zeros((max(image1.height, image2.height), image1.width + image2.width, 3))
merged_image[: image1.height, : image1.width] = np.array(image1) / 255.0
merged_image[: image2.height, image1.width :] = np.array(image2) / 255.0
plt.imshow(merged_image)
plt.axis("off")

# Retrieve the keypoints and matches
output = outputs[0]
keypoints0 = output["keypoints0"]
keypoints1 = output["keypoints1"]
matching_scores = output["matching_scores"]
keypoints0_x, keypoints0_y = keypoints0[:, 0].numpy(), keypoints0[:, 1].numpy()
keypoints1_x, keypoints1_y = keypoints1[:, 0].numpy(), keypoints1[:, 1].numpy()

# Plot the matches
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
    keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, matching_scores
):
    plt.plot(
        [keypoint0_x, keypoint1_x + image1.width],
        [keypoint0_y, keypoint1_y],
        color=plt.get_cmap("RdYlGn")(matching_score.item()),
        alpha=0.9,
        linewidth=0.5,
    )
    plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2)
    plt.scatter(keypoint1_x + image1.width, keypoint1_y, c="black", s=2)

# Save the plot
plt.savefig("matched_image.png", dpi=300, bbox_inches="tight")
plt.close()

Collaborator

Do you think it would make sense to add this to the image processor / processor?

Contributor Author

Like a plot_keypoint_matching(images, keypoint_matching_output, path) method? Or just as a docstring?

ArthurZucker (Collaborator), Jan 8, 2025

Yeah, a method sounds good.

Contributor Author

I thought about this a bit, and I think it depends on whether you want to put visualization forward in the library or not. This example assumes a single pair of images, but as a processor method, should it handle multiple pairs like the other methods? If so, should the pairs be visualized individually or all together? In terms of plotting, should we enforce the template used here or allow some customization?
On the other hand, I don't know your policy on this, but on the SuperPoint PR, another contributor took the opportunity of the new keypoint detection task to implement visualization in roboflow/supervision (that PR still appears to be a work in progress). Maybe the same could be done for keypoint matching?

Member

@qubvel Jan 8, 2025

Do we have a matplotlib dependency? Otherwise, I would rather just provide snippets in the docs and model card (as we do right now).

Contributor Author

Should this be considered resolved?

Collaborator

Yeah, we could have a soft dependency as well
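For illustration, a minimal sketch of what a plotting helper with matplotlib as a soft dependency could look like. It assumes PIL-style images (with .width, .height and .convert) and keypoints given as (x, y) pairs already in pixel coordinates; match_segments is a hypothetical name, and plot_keypoint_matching only borrows the name proposed above: neither is a transformers API. The geometry lives in a pure-Python function so it can be tested without matplotlib, and matplotlib (plus numpy, which it requires) is imported only when a plot is actually requested.

```python
import importlib.util
from typing import List, Sequence, Tuple

Point = Tuple[float, float]
Segment = Tuple[Tuple[float, float], Tuple[float, float], float]


def match_segments(
    keypoints0: Sequence[Point],
    keypoints1: Sequence[Point],
    scores: Sequence[float],
    offset_x: float,
) -> List[Segment]:
    """Return ((x0, x1 + offset_x), (y0, y1), score) for each matched pair.

    Image 1 is drawn to the right of image 0, so its x coordinates are shifted
    by offset_x (the width of image 0).
    """
    if not len(keypoints0) == len(keypoints1) == len(scores):
        raise ValueError("keypoints0, keypoints1 and scores must have the same length")
    return [
        ((x0, x1 + offset_x), (y0, y1), float(score))
        for (x0, y0), (x1, y1), score in zip(keypoints0, keypoints1, scores)
    ]


def plot_keypoint_matching(image0, image1, keypoints0, keypoints1, scores, path: str) -> None:
    """Hypothetical helper: save a side-by-side match plot; matplotlib is a soft dependency."""
    if importlib.util.find_spec("matplotlib") is None:
        raise ImportError("plot_keypoint_matching requires matplotlib: pip install matplotlib")
    import matplotlib.pyplot as plt
    import numpy as np

    # Paste both images on one canvas: image 0 on the left, image 1 on the right.
    canvas = np.zeros((max(image0.height, image1.height), image0.width + image1.width, 3))
    canvas[: image0.height, : image0.width] = np.asarray(image0.convert("RGB")) / 255.0
    canvas[: image1.height, image0.width :] = np.asarray(image1.convert("RGB")) / 255.0

    fig, ax = plt.subplots()
    ax.imshow(canvas)
    ax.axis("off")
    cmap = plt.get_cmap("RdYlGn")  # red = low matching score, green = high
    for xs, ys, score in match_segments(keypoints0, keypoints1, scores, image0.width):
        ax.plot(xs, ys, color=cmap(score), alpha=0.9, linewidth=0.5)
        ax.scatter(xs, ys, c="black", s=2)  # mark both endpoints of the match
    fig.savefig(path, dpi=300, bbox_inches="tight")
    plt.close(fig)
```

With the post-processed output from the snippet above, a call could look like plot_keypoint_matching(image1, image2, output["keypoints0"], output["keypoints1"], output["matching_scores"], "matched_image.png"), assuming both images are PIL images. Keeping the import inside the function means users who never plot never need matplotlib installed.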



def verify_model_outputs(model, model_name, device):
    from tests.models.superglue.test_modeling_superglue import prepare_imgs
Collaborator

It's better to just copy the function, as conversion files are supposed to be runnable and usable on their own!

Contributor Author

Fixed this issue

Member

@sbucaille should we remove the import then?

Member

Let's push this change and merge 🤗

Contributor Author

Yep, the import was removed in a subsequent commit.

Member

Oh, sorry, it looks like I was looking at the diff of an older commit.

Collaborator

Thanks for the clean conversion script! 🤗

src/transformers/models/superglue/modeling_superglue.py (outdated)
tests/models/superglue/test_image_processing_superglue.py (outdated)
tests/models/superglue/test_modeling_superglue.py (outdated)
@ArthurZucker
Collaborator

@sbucaille happy new year as well 🤗

@sbucaille
Contributor Author

Hi @ArthurZucker, happy new year to you too, and thanks for the review!
I've addressed and replied to your comments.
Regarding the validate_and_format_image_pairs function being too big, I've removed the cases with 4D or 5D arrays and added parameterized tests.
For the other comments, let me know what you think!
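For illustration, a minimal sketch of the kind of validation described in this comment: accept either a single pair of images or a list of pairs, flatten it into one list for batching, and reject everything else (including the dropped 4D/5D array forms). The function name validate_and_flatten_image_pairs and the is_image predicate are hypothetical; in an image processor the predicate would check for PIL images or 3D arrays. This is not the code merged in the PR.

```python
from typing import Any, Callable, List


def validate_and_flatten_image_pairs(images: Any, is_image: Callable[[Any], bool]) -> List[Any]:
    """Return a flat [pair0_img0, pair0_img1, pair1_img0, ...] list, or raise ValueError.

    Accepted inputs (hypothetical sketch):
      - a single pair: [image0, image1]
      - a list of pairs: [[image0, image1], [image0, image1], ...]
    """

    def is_pair(candidate: Any) -> bool:
        # A pair is a 2-element list/tuple whose elements are both images.
        return (
            isinstance(candidate, (list, tuple))
            and len(candidate) == 2
            and all(is_image(image) for image in candidate)
        )

    if is_pair(images):
        return list(images)
    if isinstance(images, (list, tuple)) and len(images) > 0 and all(is_pair(pair) for pair in images):
        return [image for pair in images for image in pair]
    raise ValueError(
        "Input images must be a pair of images or a list of pairs of images, "
        f"got {type(images).__name__}"
    )
```

Passing the predicate in keeps the sketch independent of any imaging library; a flat, even-length output is what a model expecting consecutive (image0, image1) entries in the batch would consume.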

Collaborator

@ArthurZucker left a comment

🤗 let's just solve the last nits and merge!

@qubvel changed the title from "Implement SuperGlue model" to "Add SuperGlue model" on Jan 20, 2025
@qubvel merged commit abe57b6 into huggingface:main on Jan 20, 2025
24 checks passed
@qubvel
Member

qubvel commented Jan 20, 2025

Congratulations @sbucaille on the model merged 🎉 🎉 🎉 Fantastic work! Thanks for iterating so many times to follow our standards 🤗

@sbucaille deleted the add_superglue branch on January 20, 2025 10:54
bursteratom pushed a commit to bursteratom/transformers that referenced this pull request Jan 31, 2025
* Initial commit with template code generated by transformers-cli

* Multiple additions to SuperGlue implementation :

- Added the SuperGlueConfig
- Added the SuperGlueModel and its implementation
- Added basic weight conversion script
- Added new ImageMatchingOutput dataclass

* Few changes for SuperGlue

* Multiple changes :
- Added keypoint detection config to SuperGlueConfig
- Completed convert_superglue_to_pytorch and successfully ran inference

* Reverted unintentional change

* Multiple changes :
 - Added SuperGlue to a bunch of places
 - Divided SuperGlue into SuperGlueForImageMatching and SuperGlueModel
 - Added testing images

* Moved things in init files

* Added docs (to be finished depending on the final implementation)

* Added necessary imports and some doc

* Removed unnecessary import

* Fixed make fix-copies bug and ran it

* Deleted SuperGlueModel
Fixed convert script

* Added SuperGlueImageProcessor

* Changed SuperGlue to support batching pairs of images and modified ImageMatchingOutput accordingly

* Changed convert_superglue_to_hf.py script to experiment with different ways of reading an image and see their impact on performance

* Added initial tests for SuperGlueImageProcessor

* Added AutoModelForImageMatching in missing places and tests

* Fixed keypoint_detector_output instructions

* Fix style

* Adapted to latest main changes

* Added integration test

* Fixed bugs to pass tests

* Added keypoints returned by keypoint detector in the output of SuperGlue

* Added doc to SuperGlue

* SuperGlue returning all attention and hidden states for a fixed number of keypoints

* Make style

* Changed SuperGlueImageProcessor tests

* Revert "SuperGlue returning all attention and hidden states for a fixed number of keypoints"
Changed tests accordingly

This reverts commit 5b3b669c

* Added back hidden_states and attentions masked outputs with tests

* Renamed ImageMatching occurrences to KeypointMatching

* Changed SuperGlueImageProcessor to raise error when batch_size is not even

* Added docs and clarity to hidden state and attention grouping function

* Fixed some code and did some refactoring

* Fixed typo in SuperPoint output doc

* Fixed some of the formatting and variable naming problems

* Removed useless function call

* Removed AutoModelForKeypointMatching

* Fixed SuperGlueImageProcessor to only accept pairs of images

* Added more fixes to SuperGlueImageProcessor

* Simplified the batching of attention and hidden states

* Simplified stack functions

* Moved attention instructions into class

* Removed unused do_batch_norm argument

* Moved weight initialization to the proper place

* Replaced deepcopy for instantiation

* Fixed small bug

* Changed from stevenbucaille to magic-leap repo

* Renamed London Bridge images to Tower Bridge

* Fixed formatting

* Renamed remaining "london" to "tower"

* Apply suggestions from code review

Small changes in the docs

Co-authored-by: amyeroberts <[email protected]>

* Added AutoModelForKeypointMatching

* Changed images used in example

* Several changes to image_processing_superglue and style

* Fixed resample type hint

* Changed SuperGlueImageProcessor and added test case for list of 2 images

* Changed list_of_tuples implementation

* Fix in dummy objects

* Added normalize_keypoint, log_sinkhorn_iterations and log_optimal_transport docstring

* Added missing docstring

* Apply suggestions from code review

Co-authored-by: amyeroberts <[email protected]>

* Apply suggestions from code review

Co-authored-by: amyeroberts <[email protected]>

* Moved forward block at bottom

* Added docstring to forward method

* Added docstring to match_image_pair method

* Changed test_model_common_attributes to test_model_get_set_embeddings test method signature

* Removed AutoModelForKeypointMatching

* Removed image fixtures and added load_dataset

* Added padding of images in SuperGlueImageProcessor

* Cleaned up convert_superglue_to_hf script

* Added missing docs and fixed unused argument

* Fixed SuperGlueImageProcessor tests

* Transposed all hidden states from SuperGlue to reflect the standard (..., seq_len, feature_dim) shape

* Added SuperGlueForKeypointMatching back to modeling_auto

* Fixed image processor padding test

* Changed SuperGlue docs

* changes:
 - Abstraction to batch, concat and stack of inconsistent tensors
 - Changed conv1d's to linears to match standard attention implementations
 - Renamed all tensors to be tensor0 and not tensor_0 and be consistent
 - Changed match image pair to run keypoint detection on all images first, create batching tensors, and then fill these tensors match by match
 - Various changes in docs, etc

* Changes to SuperGlueImageProcessor:
- Reworked the input image pairs checking function and added tests accordingly
- Added Copied from statements
- Added do_grayscale tag (also for SuperPointImageProcessor)
- Misc changes for better code

* Formatting changes

* Reverted conv1d to linear conversion because of numerical differences

* fix: changed some code to be more straightforward (e.g. filtering keypoints) and converted plot from opencv to matplotlib

* fix: removed unnecessary test

* chore: removed commented code and added back hidden states transpositions

* chore: changed from "inconsistent" to "ragged" function names as suggested

Co-authored-by: amyeroberts <[email protected]>

* docs: applied suggestions

Co-authored-by: amyeroberts <[email protected]>

* docs: updated to display matched output

* chore: applied suggestion for check_image_pairs_input function

Co-authored-by: amyeroberts <[email protected]>

* chore: changed check_image_pairs_input function name to validate_and_format_image_pairs and used validate_preprocess_arguments function

* tests: simplified tests for image input format and shapes

* feat: replaced SuperGlue's use of Conv1d with a kernel_size of 1 with Linear layers. Changed tests and conversion script accordingly

* feat: several changes to address comments

Conversion script:
- Reverted fuse batchnorm to linear conversion
- Changed all 'nn.Module' to respective SuperGlue models
- Changed conversion script to use regex mapping and match other recent scripts

Modeling SuperGlue:
- Added batching with mask and padding to attention
- Removed unnecessary concat, stack and batch ragged pairs functions
- Reverted batchnorm layer
- Renamed query, key, value and merge layers into q, k, v, out proj
- Removed Union of different Module into nn.Module in _init_weights method typehint
- Changed several methods' signatures to combine image0 and image1 inputs, with appropriate doc changes
- Updated SuperGlue's doc with torch.no_grad()

Updated test to reflect changes in SuperGlue model

* refactor: changed validate_and_format_image_pairs function with clarity

* refactor: changed from one SuperGlueMLP class to a list of SuperGlueMLP class

* fix: fixed forgotten init weight change from last commit

* fix: fixed rebase mistake

* fix: removed leftover commented code

* fix: added typehint and changed some of arguments default values

* fix: fixed attribute default values for SuperGlueConfig

* feat: added SuperGlueImageProcessor post process keypoint matching method with tests

* fix: fixed SuperGlue attention and hidden state tuples aggregation

* chore: fixed mask optionality and reordered tensor reshapes to be cleaner

* chore: fixed docs and error message returned in validate_and_format_image_pairs function

* fix: fixed returned keypoints to be the ones that SuperPoint returns

* fix: fixed check on number of image sizes for post process compared to the pairs in outputs of SuperGlue

* fix: fixed check on number of image sizes for post process compared to the pairs in outputs of SuperGlue (bis)

* fix: Changed SuperGlueMultiLayerPerceptron instantiation to avoid if statement

* fix: Changed convert_superglue_to_hf script to reflect latest SuperGlue changes and got rid of nn.Modules

* WIP: implement Attention from an existing class (like BERT)

* docs: Changed docs to include more appealing matching plot

* WIP: Implement Attention

* chore: minor typehint change

* chore: changed convert superglue script by removing all classes and apply conv to linear conversion in state dict + rearrange keys to comply with changes in model's layers organisation

* Revert "Fixed typo in SuperPoint output doc"

This reverts commit 2120390.

* chore: added comments in SuperGlueImageProcessor

* chore: changed SuperGlue organization HF repo to magic-leap-community

* [run-slow] refactor: small change in layer instantiation

* [run-slow] chore: replaced remaining stevenbucaille org to magic-leap-community

* [run-slow] chore: make style

* chore: update image matching fixture dataset HF repository

* [run-slow] superglue

* tests: overwriting test_batching_equivalence

* [run-slow] superglue

* tests: changed test to cope with value changing depending on cuda version

* [run-slow] superglue

* tests: changed matching_threshold value

* [run-slow] superglue

* [run-slow] superglue

* tests: changed tests for integration

* [run-slow] superglue

* fix: Changed tensor view and permutations to match original implementation results

* fix: updated convert script and integration test to include last change in model

* fix: increase tolerance for CUDA variances

* Apply suggestions from code review

Co-authored-by: Pavel Iakubovskii <[email protected]>

* [run-slow] superglue

* chore: removed blank whitespaces

* [run-slow] superglue

* Revert SuperPoint image processor accidental changes

* [run-slow] superglue

* refactor: reverted copy from BERT class

* tests: lower the tolerance in integration tests for SuperGlue

* [run-slow] superglue

* chore: set do_grayscale to False in SuperPoint and SuperGlue image processors

* [run-slow] superglue

* fix: fixed imports in SuperGlue files

* chore: changed do_grayscale SuperGlueImageProcessing default value to True

* docs: added typehint to post_process_keypoint_matching method in SuperGlueImageProcessor

* fix: set matching_threshold default value to 0.0 instead of 0.2

* feat: added matching_threshold to post_process_keypoint_matching method

* docs: update superglue.md to include matching_threshold parameter

* docs: updated SuperGlueConfig docstring for matching_threshold default value

* refactor: removed unnecessary parameters in SuperGlueConfig

* fix: changed from matching_threshold to threshold

* fix: re-revert changes to make SuperGlue attention classes copies of BERT

* [run-slow] superglue

* fix: added missing device argument in post_processing method

* [run-slow] superglue

* fix: add matches different from -1 to compute valid matches in post_process_keypoint_matching (and docstring)

* fix: add device to image_sizes tensor instantiation

* tests: added checks on do_grayscale test

* chore: reordered and added Optional typehint to KeypointMatchingOutput

* LightGluePR suggestions:
- use `post_process_keypoint_matching` as default docs example
- add `post_process_keypoint_matching` in autodoc
- add `SuperPointConfig` import under TYPE_CHECKING condition
- format SuperGlueConfig docstring
- add device in convert_superglue_to_hf
- Fix typo
- Fix KeypointMatchingOutput docstring
- Removed unnecessary line
- Added missing SuperGlueConfig in __init__ methods

* LightGluePR suggestions:
- use batching to get keypoint detection

* refactor: processing images done in 1 for loop instead of 4

* fix: use @ instead of torch.einsum for scores computation

* style: added #fmt skip to long tensor values

* refactor: rolled back validate_and_format_image_pairs valid and invalid cases to simpler ones

* refactor: prepare_imgs

* refactor: simplified `validate_and_format_image_pairs`

* docs: fixed doc

---------

Co-authored-by: steven <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
Co-authored-by: Steven Bucaille <[email protected]>
Co-authored-by: Pavel Iakubovskii <[email protected]>
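The commit entry about log_sinkhorn_iterations and log_optimal_transport refers to SuperGlue's matching core: the score matrix between the keypoints of two images is augmented with a "dustbin" row and column for unmatched points, and Sinkhorn normalization is run in log space to turn it into a soft assignment. Below is a minimal, unbatched pure-Python sketch of that technique, assuming a fixed dustbin score alpha (a learned parameter in the model) and nested lists instead of tensors; it illustrates the idea and is not the code merged in this PR.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v) for v in values))."""
    peak = max(values)
    if peak == -math.inf:
        return -math.inf
    return peak + math.log(sum(math.exp(v - peak) for v in values))


def log_sinkhorn_iterations(z: Matrix, log_mu: List[float], log_nu: List[float], iters: int) -> Matrix:
    """Alternately rescale rows and columns (in log space) towards marginals mu and nu."""
    rows, cols = len(z), len(z[0])
    u, v = [0.0] * rows, [0.0] * cols
    for _ in range(iters):
        u = [log_mu[i] - logsumexp([z[i][j] + v[j] for j in range(cols)]) for i in range(rows)]
        v = [log_nu[j] - logsumexp([z[i][j] + u[i] for i in range(rows)]) for j in range(cols)]
    return [[z[i][j] + u[i] + v[j] for j in range(cols)] for i in range(rows)]


def log_optimal_transport(scores: Matrix, alpha: float, iters: int) -> Matrix:
    """Log soft assignment for an m x n score matrix, with an extra dustbin row and column."""
    m, n = len(scores), len(scores[0])
    # Append the dustbin column to every row, then the dustbin row; all dustbin entries score alpha.
    couplings = [list(row) + [alpha] for row in scores] + [[alpha] * (n + 1)]
    # Each real keypoint carries mass 1, the dustbin row mass n and the dustbin column mass m.
    # Everything is divided by m + n so both marginals sum to 1.
    norm = -math.log(m + n)
    log_mu = [norm] * m + [math.log(n) + norm]
    log_nu = [norm] * n + [math.log(m) + norm]
    z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
    # Undo the 1 / (m + n) normalization.
    return [[value - norm for value in row] for row in z]
```

After enough iterations, each of the first m rows and first n columns of exp(result) sums to about 1, the dustbin row to n and the dustbin column to m. Working in log space avoids overflow when the raw scores are large.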
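The later commit entries on post_process_keypoint_matching (a threshold defaulting to 0.0, and only matches different from -1 counted as valid) describe a simple filtering rule. A pure-Python sketch of that rule for a single image pair, assuming matches0[i] holds the index of the image-1 keypoint matched to keypoint i of image 0 (or -1 if unmatched) and scores0[i] its matching score; filter_matches is a hypothetical helper, the strict "greater than" comparison is an assumption, and the actual processor method operates on batched tensors.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def filter_matches(
    keypoints0: Sequence[Point],
    keypoints1: Sequence[Point],
    matches0: Sequence[int],
    scores0: Sequence[float],
    threshold: float = 0.0,
) -> Tuple[List[Point], List[Point], List[float]]:
    """Keep pairs keypoints0[i] <-> keypoints1[matches0[i]] that are matched and confident."""
    kept0: List[Point] = []
    kept1: List[Point] = []
    kept_scores: List[float] = []
    for i, (j, score) in enumerate(zip(matches0, scores0)):
        # -1 marks a keypoint with no match in image 1 (assigned to the dustbin).
        if j != -1 and score > threshold:
            kept0.append(keypoints0[i])
            kept1.append(keypoints1[j])
            kept_scores.append(score)
    return kept0, kept1, kept_scores
```

Raising threshold trades recall for precision: with the default of 0.0 every valid match is kept, while a higher value drops low-confidence pairs.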
elvircrn pushed a commit to elvircrn/transformers that referenced this pull request Feb 13, 2025

Successfully merging this pull request may close these issues.

Implement SuperPoint / SuperGlue
5 participants