Skip to content

Commit

Permalink
Merge pull request #1 from ZHEQIUSHUI/update_1020
Browse files Browse the repository at this point in the history
optimize gui
  • Loading branch information
ZHEQIUSHUI authored Oct 20, 2023
2 parents 55af0ff + 8adad2d commit 4fc0a13
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 454 deletions.
1 change: 1 addition & 0 deletions qtproj/SAMQT/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CMakeLists.txt.user*
436 changes: 0 additions & 436 deletions qtproj/SAMQT/CMakeLists.txt.user

This file was deleted.

33 changes: 31 additions & 2 deletions qtproj/SAMQT/mainwindow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "QMessageBox"
#include "QIntValidator"

MainWindow::MainWindow(std::string encoder_model_path,std::string decoder_model_path,std::string inpaint_model_path , QWidget *parent)
MainWindow::MainWindow(std::string encoder_model_path, std::string decoder_model_path, std::string inpaint_model_path, QWidget *parent)
: QMainWindow(parent), ui(new Ui::MainWindow)
{
ui->setupUi(this);
Expand Down Expand Up @@ -56,7 +56,7 @@ void MainWindow::on_btn_remove_obj_clicked()
dilate_size = 111;
if (dilate_size < 5)
dilate_size = 5;
this->ui->label->ShowRemoveObject(dilate_size, this->ui->progressBar_remove_obj);
this->ui->label->ShowRemoveObject(dilate_size, this->ui->progressBar_remove_obj, ui->ch_merge_mask->isChecked());
this->setEnabled(true);
}

Expand Down Expand Up @@ -107,3 +107,32 @@ void MainWindow::on_radioButton_box_clicked()
{
this->ui->label->SetBoxPrompt(this->ui->radioButton_box->isChecked());
}

void MainWindow::on_btn_save_img_clicked()
{
auto cur_image = this->ui->label->getCurrentImage();

QString filename = QFileDialog::getSaveFileName(this,
tr("Save Image"),
"",
tr("*.bmp;; *.png;; *.jpg")); // 选择路径
if (filename.isEmpty())
{
return;
}
else
{
if (!(filename.endsWith(".bmp") || filename.endsWith(".png") || filename.endsWith(".jpg")))
{
filename += ".png";
}

if (!(cur_image.save(filename))) // 保存图像
{
QMessageBox::information(this,
tr("Failed to save the image"),
tr("Failed to save the image!"));
return;
}
}
}
2 changes: 2 additions & 0 deletions qtproj/SAMQT/mainwindow.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ private slots:

void on_radioButton_box_clicked();

void on_btn_save_img_clicked();

private:
Ui::MainWindow *ui;
};
Expand Down
14 changes: 14 additions & 0 deletions qtproj/SAMQT/mainwindow.ui
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@
</item>
</layout>
</item>
<item>
<widget class="QCheckBox" name="ch_merge_mask">
<property name="text">
<string>MergeMask</string>
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="btn_remove_obj">
<property name="enabled">
Expand Down Expand Up @@ -134,6 +141,13 @@
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="btn_save_img">
<property name="text">
<string>SaveImage</string>
</property>
</widget>
</item>
</layout>
</item>
<item>
Expand Down
68 changes: 52 additions & 16 deletions qtproj/SAMQT/myqlabel.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class myQLabel : public QLabel
bool mouseHolding = false;
QPoint pt_img_first, pt_img_secend;
SAM mSam;
// LamaInpaintOnnx mInpaint;
// LamaInpaintOnnx mInpaint;
std::shared_ptr<LamaInpaint> mInpaint;

void dragEnterEvent(QDragEnterEvent *event) override
Expand Down Expand Up @@ -181,6 +181,11 @@ class myQLabel : public QLabel
{
}

QImage getCurrentImage()
{
return cur_image;
}

void SetImage(QImage img)
{
cur_masks.clear();
Expand Down Expand Up @@ -216,7 +221,7 @@ class myQLabel : public QLabel
void InitModel(std::string encoder_model, std::string decoder_model, std::string inpaint_model)
{
mSam.Load(encoder_model, decoder_model);
// mInpaint.Load(inpaint_model);
// mInpaint.Load(inpaint_model);

if (string_utility<std::string>::ends_with(inpaint_model, ".onnx"))
{
Expand All @@ -233,7 +238,7 @@ class myQLabel : public QLabel
mInpaint->Load(inpaint_model);
}

void ShowRemoveObject(int dilate_size, QProgressBar *bar)
void ShowRemoveObject(int dilate_size, QProgressBar *bar, bool remove_mask_by_merge = true)
{
if (!cur_image.bits() || !grab_masks.size())
{
Expand All @@ -257,21 +262,52 @@ class myQLabel : public QLabel
bar->setMinimum(0);
bar->setMaximum(grab_masks.size());
}
for (auto grab_mask : grab_masks)
if (remove_mask_by_merge)
{
auto time_start = std::chrono::high_resolution_clock::now();
inpainted = mInpaint->Inpaint(inpainted, grab_mask, dilate_size);
auto time_end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = time_end - time_start;
std::cout << "Inpaint Inference Cost time : " << diff.count() << "s" << std::endl;
QImage qinpainted(inpainted.data, inpainted.cols, inpainted.rows, inpainted.step1(), QImage::Format_BGR888);
cur_image = qinpainted.copy();
if (cur_masks.size())
cur_masks.removeFirst();
repaint();
if (bar)
bar->setValue(bar->value() + 1);
if (grab_masks.size())
{
auto base_mask = grab_masks[0];
if (bar)
bar->setValue(bar->value() + 1);

// merge all mask
for (size_t i = 1; i < grab_masks.size(); i++)
{
base_mask |= grab_masks[i];
if (bar)
bar->setValue(bar->value() + 1);
}

auto time_start = std::chrono::high_resolution_clock::now();
inpainted = mInpaint->Inpaint(inpainted, base_mask, dilate_size);
auto time_end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = time_end - time_start;
std::cout << "Inpaint Inference Cost time : " << diff.count() << "s" << std::endl;
QImage qinpainted(inpainted.data, inpainted.cols, inpainted.rows, inpainted.step1(), QImage::Format_BGR888);
cur_image = qinpainted.copy();
cur_masks.clear();
repaint();
}
}
else
{
for (auto grab_mask : grab_masks)
{
auto time_start = std::chrono::high_resolution_clock::now();
inpainted = mInpaint->Inpaint(inpainted, grab_mask, dilate_size);
auto time_end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = time_end - time_start;
std::cout << "Inpaint Inference Cost time : " << diff.count() << "s" << std::endl;
QImage qinpainted(inpainted.data, inpainted.cols, inpainted.rows, inpainted.step1(), QImage::Format_BGR888);
cur_image = qinpainted.copy();
if (cur_masks.size())
cur_masks.removeFirst();
repaint();
if (bar)
bar->setValue(bar->value() + 1);
}
}

cur_masks.clear();
rgba_masks.clear();
grab_masks.clear();
Expand Down

0 comments on commit 4fc0a13

Please sign in to comment.