Skip to content

Commit

Permalink
superpoint torch using simple nms
Browse files Browse the repository at this point in the history
  • Loading branch information
borongyuan committed Feb 3, 2024
1 parent b2a86d6 commit e71bb55
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions corelib/src/superpoint_torch/SuperPoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
*/

#include <superpoint_torch/SuperPoint.h>
#include <rtabmap/core/util2d.h>
#include <rtabmap/utilite/ULogger.h>
#include <rtabmap/utilite/UDirectory.h>
#include <rtabmap/utilite/UFile.h>
Expand Down Expand Up @@ -155,49 +154,50 @@ std::vector<cv::KeyPoint> SPDetector::detect(const cv::Mat &img, const cv::Mat &
x = x.set_requires_grad(false);
auto out = model_->forward(x.to(device));

prob_ = out[0].squeeze(0); // [H, W]
auto scores = out[0]; // [1, H, W]
desc_ = out[1]; // [1, 256, H/8, W/8]

if(nms_)
{
auto options = torch::nn::functional::MaxPool2dFuncOptions(minDistance_*2+1).stride(1).padding(minDistance_);
auto options_r1 = torch::nn::functional::MaxPool2dFuncOptions(3).stride(1).padding(1);

auto zeros = torch::zeros_like(scores);
auto max_mask = scores == torch::nn::functional::max_pool2d(scores, options);
auto max_mask_r1 = scores == torch::nn::functional::max_pool2d(scores, options_r1);
for(size_t i=0; i<2; i++)
{
auto supp_mask = torch::nn::functional::max_pool2d(max_mask.to(torch::kF32), options) > 0;
auto supp_scores = torch::where(supp_mask, zeros, scores);
auto new_max_mask = supp_scores == torch::nn::functional::max_pool2d(supp_scores, options);
max_mask = max_mask | (new_max_mask & (~supp_mask) & max_mask_r1);
}
prob_ = torch::where(max_mask, scores, zeros).squeeze(0);
}
else
{
prob_ = scores.squeeze(0);
}

auto kpts = (prob_ > threshold_);
kpts = torch::nonzero(kpts); // [n_keypoints, 2] (y, x)

//convert back to cpu if in gpu
auto kpts_cpu = kpts.to(torch::kCPU);
auto prob_cpu = prob_.to(torch::kCPU);

std::vector<cv::KeyPoint> keypoints_no_nms;
for (int i = 0; i < kpts_cpu.size(0); i++) {
std::vector<cv::KeyPoint> keypoints;
for(int i=0; i<kpts_cpu.size(0); i++)
{
if(mask.empty() || mask.at<unsigned char>(kpts_cpu[i][0].item<int>(), kpts_cpu[i][1].item<int>()) != 0)
{
float response = prob_cpu[kpts_cpu[i][0]][kpts_cpu[i][1]].item<float>();
keypoints_no_nms.push_back(cv::KeyPoint(kpts_cpu[i][1].item<float>(), kpts_cpu[i][0].item<float>(), 8, -1, response));
keypoints.emplace_back(cv::KeyPoint(kpts_cpu[i][1].item<float>(), kpts_cpu[i][0].item<float>(), 8, -1, response));
}
}

detected_ = true;
if (nms_ && !keypoints_no_nms.empty()) {
int border = 0;
int dist_thresh = minDistance_;
int height = img.rows;
int width = img.cols;

std::vector<cv::KeyPoint> keypoints;
cv::Mat descEmpty;
util2d::NMS(keypoints_no_nms, descEmpty, keypoints, descEmpty, border, dist_thresh, width, height);
if(keypoints.size()>1)
{
return keypoints;
}
return std::vector<cv::KeyPoint>();
}
else if(keypoints_no_nms.size()>1)
{
return keypoints_no_nms;
}
else
{
return std::vector<cv::KeyPoint>();
}
return keypoints;
}
else
{
Expand Down

0 comments on commit e71bb55

Please sign in to comment.