This repository has been archived by the owner on Nov 21, 2023. It is now read-only.

Detectron in C++ #199

Open
marcelogarcia94 opened this issue Feb 27, 2018 · 32 comments

Comments

@marcelogarcia94

Hello,
I would like to use Detectron in C++. Is there any way to do it?

Thanks.

@orionr

orionr commented Feb 27, 2018

Some of Detectron requires Python ops to function. However, there is support to export a limited number of models to straight Caffe2 C++ ops at

https://github.com/facebookresearch/Detectron/blob/master/tools/convert_pkl_to_pb.py

These models are also available in the Caffe2 model zoo at

https://github.com/caffe2/models/tree/master/detectron

Hope that helps.

@dbrazey

dbrazey commented Feb 28, 2018

I would also like to run a Detectron CNN from C++ code (for inference only).
I managed to compile and run a simple classification CNN using the Caffe2 framework.
Now I am trying to run a Detectron CNN for object detection.

I tried to use my classification code to load and run a net downloaded from https://github.com/caffe2/models/tree/master/detectron.
1 - However, I get the following error when I run the net :

> terminate called after throwing an instance of 'caffe2::EnforceNotMet'
>   what():  [enforce fail at generate_proposals_op.cc:205] im_info_tensor.dims() == (vector<TIndex>{num_images, 3}).  vs  Error from operator: 
> input: "rpn_cls_probs" input: "rpn_bbox_pred" input: "im_info" input: "anchor" output: "rpn_rois" output: "rpn_roi_probs" name: "" type: "GenerateProposals" arg { name: "nms_thres" f: 0.7 } arg { name: "min_size" f: 0 } arg { name: "spatial_scale" f: 0.0625 } arg { name: "correct_transform_coords" i: 1 } arg { name: "post_nms_topN" i: 1000 } arg { name: "pre_nms_topN" i: 6000 }
> *** Aborted at 1519832615 (unix time) try "date -d @1519832615" if you are using GNU date ***
> PC: @     0x7f80d5190067 (unknown)
> *** SIGABRT (@0x3e800002ef1) received by PID 12017 (TID 0x7f80db701f00) from PID 12017; stack trace: ***
>     @     0x7f80d51900e0 (unknown)
>     @     0x7f80d5190067 (unknown)
>     @     0x7f80d5191448 (unknown)
>     @     0x7f80d5a7db3d (unknown)
>     @     0x7f80d5a7bbb6 (unknown)
>     @     0x7f80d5a7bc01 (unknown)
>     @     0x7f80d5a7be69 (unknown)
>     @     0x7f80d845cee8 caffe2::Operator<>::Run()
>     @     0x7f80d76f5627 caffe2::SimpleNet::Run()
>     @     0x7f80d76afcda caffe2::Workspace::RunNetOnce()
>     @           0x431f82 ClassifieurCNN2::Process()
>     @           0x422075 caffe2::customclass()
>     @           0x42217b main
>     @     0x7f80d517cb45 (unknown)
>     @           0x41efc9 (unknown)
>     @                0x0 (unknown)

It seems to be a tensor size problem.

2 - In py-faster rcnn (caffe 1 + custom python layers), we needed to create two blobs "data" and "im_info".
In detectron, must I also define an extra tensor ?

3 - Did I miss any example program written in C++ for inference with detectron ?
Do you know which detection algorithms can be run using only C++ ops ?

Thanks for your great work :)

@dbrazey

dbrazey commented Mar 1, 2018

I solved problems 1 and 2.
I have used "generate_proposals_op_test.cc" as an example.

@signal926

signal926 commented Apr 18, 2018

@dbrazey Hi, I'm writing C++ code to run a Detectron model and have hit a problem similar to your problem 1.
I get this error:
[enforce fail at generate_proposals_op.cc:221] im_info_tensor.dims() == (vector<TIndex>{num_images, 3}). vs Error from operator: input: "rpn_cls_probs_cpu" input: "rpn_bbox_pred_cpu" input: "im_info" input: "anchor" output: "rpn_rois" output: "rpn_roi_probs" name: "" type: "GenerateProposals" arg { name: "nms_thres" f: 0.7 } arg { name: "min_size" f: 0 } arg { name: "spatial_scale" f: 0.0625 } arg { name: "correct_transform_coords" i: 1 } arg { name: "post_nms_topN" i: 1000 } arg { name: "pre_nms_topN" i: 6000 }

Can you give me some suggestions about this problem? Thanks a lot.

@lupotto

lupotto commented Apr 20, 2018

Hello,

I followed the same process as @dbrazey: first I experimented with a CNN for classification, and now I am facing the same problem you described. I tried to load e2e_faster_rcnn_R-50-C4_2x in Caffe2, but this error pops up:

RuntimeError: [enforce fail at generate_proposals_op.cc:205] im_info_tensor.dims() == (vector<TIndex>{num_images, 3}). vs Error from operator:
input: "rpn_cls_probs" input: "rpn_bbox_pred" input: "im_info" input: "anchor" output: "rpn_rois" output: "rpn_roi_probs" name: "" type: "GenerateProposals" arg { name: "nms_thres" f: 0.7 } arg { nam

I tried looking at "generate_proposals_op_test" but couldn't work it out. It would be very helpful if you could give me some suggestions. Thank you!

@dbrazey

dbrazey commented Apr 25, 2018

I think you get the error because you didn't create the "im_info" blob (or created it with the wrong size).
This blob must be created in your workspace, with the usual "data" blob.

In "generate_proposals_op_test", you have :

 def.add_input("im_info");
// some stuff 
AddConstInput(vector<TIndex>{img_count, 3}, 0.1, "im_info", &ws);

which means

m_buffer_im_info = new float[3];
m_buffer_im_info[0] = 800; // height
m_buffer_im_info[1] = 800; // width
m_buffer_im_info[2] = 1; // image scale factor (see the py-faster-rcnn C++ examples)

std::vector<caffe2::TIndex> shape{(int) 1, (int) 3}; // 1 image, 3 values 
auto tensor_im_info = m_workspace->CreateBlob("im_info")->GetMutable<caffe2::TensorCPU>();
tensor_im_info->Resize(shape);
tensor_im_info->ShareExternalPointer(m_buffer_im_info);

// im_info is now initialized with correct values 
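
For readers hitting the same enforce failure, here is a minimal sketch of the layout the check asks for: one row of three floats per image, giving a {num_images, 3} shape. It uses plain C++ containers rather than Caffe2 tensors, and `ImInfo`, `packImInfo` and `hasImInfoShape` are hypothetical names introduced only for illustration. The (height, width, scale) order follows the image-reading snippet posted later in this thread.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper type, not part of Caffe2: one record per input image.
struct ImInfo {
    float height;  // image height in pixels
    float width;   // image width in pixels
    float scale;   // factor applied when the original image was resized
};

// Flattens the records into a row-major buffer of shape {N, 3}.
std::vector<float> packImInfo(const std::vector<ImInfo>& images) {
    std::vector<float> buffer;
    buffer.reserve(images.size() * 3);
    for (const ImInfo& info : images) {
        buffer.push_back(info.height);
        buffer.push_back(info.width);
        buffer.push_back(info.scale);
    }
    return buffer;
}

// Returns true when dims is exactly {numImages, 3}, the shape named in the
// enforce message.
bool hasImInfoShape(const std::vector<long long>& dims, std::size_t numImages) {
    return dims.size() == 2 &&
           dims[0] == static_cast<long long>(numImages) &&
           dims[1] == 3;
}
```

The packed buffer is what you would hand to the "im_info" tensor after resizing it to {N, 3}; a 1-D tensor of 3 values does not pass the shape check.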

@dongmingsun

Hi there, I wonder if the link https://github.com/facebookresearch/Detectron/blob/master/tools/convert_pkl_to_pb.py
can be applied to Mask R-CNN?
I'd like to use the trained model in my C++ code, but I'm not sure if it's feasible.

@gaqiness

+1

@SerinaWei

@dbrazey Can you please provide a code snippet on how to run the object detection net (e.g. using e2e_faster_rcnn_R-50-C4_1x or similar networks). I was able to convert the models to .pb, and also provide "im_info" and "data" blobs to the workspace, but I still cannot run the network. Thanks!

@gadcam
Contributor

gadcam commented May 23, 2018

@SerinaWei You have an example here of how to run it https://github.com/daquexian/Detectron/blob/27f5aee53785af99147a634344d58a70bfbd250e/tools/convert_pkl_to_pb.py#L501-L549

> I wonder if the link https://github.com/facebookresearch/Detectron/blob/master/tools/convert_pkl_to_pb.py can be applied to Mask R-CNN?

@dongmingsun Not yet, but I have it working and will open a PR as soon as my code is clean enough. On its own it will not be enough to train/run the model in C++, though; you would need to implement a few more things.

@SerinaWei

@gadcam I was able to use the .pb in python and would like to run the model in C++. I have set the "im_info" and "data" blobs. What else do I need to implement to run? If you can share your insight on this, I really appreciate it!

@gadcam
Contributor

gadcam commented May 23, 2018

@SerinaWei Sorry I read it too fast and did not check that you were speaking about C++.
I just read in #432 (comment) that https://github.com/leonardvandriel/caffe2_cpp_tutorial could be useful.

@SerinaWei

@gadcam I read leonard's tutorial. It helps with classification in C++, but it doesn't seem to have any examples that use a detection net.

@SerinaWei

@rbgirshick Can you please provide insight/examples on how to use detection models (e2e_faster_rcnn_R-50-C4_1x or similar networks) in C++ in Detectron? I've been struggling for a while. Thanks!

@HappyKerry

@SerinaWei How do you provide the "im_info" and "data" blobs to the workspace? Thanks

@daquexian
Contributor

daquexian commented May 24, 2018

@SerinaWei @HappyKerry

Here is a code snippet. It's actually not hard to find similar code in tutorials (although, frankly, Caffe2 tutorials are rare).

        auto inputBlobs = getInputBlobFromFile(FLAGS_file);

        caffe2::Blob* blob = workspace->GetBlob("data");
        auto *tensor = blob->GetMutable<caffe2::TensorCPU>();
        tensor->ResizeLike(inputBlobs.first);
        tensor->ShareData(inputBlobs.first);

        blob = workspace->GetBlob("im_info");
        tensor = blob->GetMutable<caffe2::TensorCPU>();
        tensor->ResizeLike(inputBlobs.second);
        tensor->ShareData(inputBlobs.second);

        net->Run();
        blob = workspace->GetBlob("score_nms");
        auto score_tensor = blob->Get<caffe2::TensorCPU>();
        blob = workspace->GetBlob("class_nms");
        auto class_tensor = blob->Get<caffe2::TensorCPU>();
        blob = workspace->GetBlob("bbox_nms");
        auto bbox_tensor = blob->Get<caffe2::TensorCPU>();

        for (int i = 0; i < bbox_tensor.dim(0); i++) {
            LOG(INFO) << "class: " << class_tensor.data<float>()[i];
            LOG(INFO) << "score: " << score_tensor.data<float>()[i];
            LOG(INFO) << "bbox: "
                      << bbox_tensor.data<float>()[i * 4] << ", "
                      << bbox_tensor.data<float>()[i * 4 + 1] << ", "
                      << bbox_tensor.data<float>()[i * 4 + 2] << ", "
                      << bbox_tensor.data<float>()[i * 4 + 3];
        }
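
The loop above reads the outputs as flat float arrays: one class value and one score per detection, and four box values per detection. As a rough sketch of how those arrays could be grouped and filtered, the helper below works on plain `std::vector<float>` copies of the buffers. `Detection`, `collectDetections` and the `minScore` threshold are hypothetical names, and the (x1, y1, x2, y2) box order is an assumption that should be checked against your exported model.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical record for one detection; field names are illustrative.
struct Detection {
    int classId;
    float score;
    float x1, y1, x2, y2;  // assumed box order, verify for your model
};

// Groups the flat output buffers into records and drops entries whose
// score is below minScore. boxes holds 4 floats per detection.
std::vector<Detection> collectDetections(const std::vector<float>& classes,
                                         const std::vector<float>& scores,
                                         const std::vector<float>& boxes,
                                         float minScore) {
    assert(classes.size() == scores.size());
    assert(boxes.size() == scores.size() * 4);
    std::vector<Detection> result;
    for (std::size_t i = 0; i < scores.size(); ++i) {
        if (scores[i] < minScore) {
            continue;
        }
        result.push_back(Detection{static_cast<int>(classes[i]), scores[i],
                                   boxes[i * 4], boxes[i * 4 + 1],
                                   boxes[i * 4 + 2], boxes[i * 4 + 3]});
    }
    return result;
}
```

The vectors would be filled from `class_tensor.data<float>()`, `score_tensor.data<float>()` and `bbox_tensor.data<float>()` with `bbox_tensor.dim(0)` entries each (times four for the boxes).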

Use all-zero fake input:

pair<caffe2::TensorCPU, caffe2::TensorCPU> getInputBlobFromFile(const string& filename) {

    vector<caffe2::TIndex> img_dims({1, 3, SCALES, SCALES});
    caffe2::TensorCPU t(img_dims);
    memset(t.mutable_data<float>(), 0, 3 * SCALES * SCALES * sizeof(float));

    vector<caffe2::TIndex> im_info_dims({1, 3});
    caffe2::TensorCPU im_info(im_info_dims);
    im_info.mutable_data<float>()[0] = SCALES;
    im_info.mutable_data<float>()[1] = SCALES;
    im_info.mutable_data<float>()[2] = 1;

    return pair<caffe2::TensorCPU, caffe2::TensorCPU>(t, im_info);
}

If you want to read an image as input (I use OpenCV to read the image; SCALES, MAX_SIZE and FPN_COARSE_STRIDE have the same meaning as in Detectron's config.py, which documents them):

pair<caffe2::TensorCPU, caffe2::TensorCPU> getInputBlobFromFile(const string& filename) {
    auto img = cv::imread(filename);

    int imgHeight = img.rows;
    int imgWidth = img.cols;

    int minSize = img.rows < img.cols ? img.rows : img.cols;
    int maxSize = img.rows < img.cols ? img.cols : img.rows;

    double scale = 1. * SCALES / minSize;
    if (maxSize * scale > MAX_SIZE) {
        scale = 1. * MAX_SIZE / maxSize;
    }

    resize(img, img, cv::Size(0, 0), scale, scale);
    int padHeight = (img.rows + FPN_COARSE_STRIDE - 1) / FPN_COARSE_STRIDE * FPN_COARSE_STRIDE;
    int padWidth = (img.cols + FPN_COARSE_STRIDE - 1) / FPN_COARSE_STRIDE * FPN_COARSE_STRIDE;
    Mat imgPadded(padHeight, padWidth, img.type());
    cv::copyMakeBorder(img, imgPadded, 0, padHeight - img.rows, 0, padWidth - img.cols, cv::BORDER_CONSTANT, 0);
    imgPadded.convertTo(imgPadded, CV_32FC3);
    imgPadded -= cv::Scalar(102.9801, 115.9465, 122.771);
    vector<Mat> channels(3);
    split(imgPadded, channels);
    vector<float> data;
    for (auto &c : channels) {
        data.insert(data.end(), (float *)c.datastart, (float *)c.dataend);
    }
    vector<caffe2::TIndex> img_dims({1, 3, imgPadded.rows, imgPadded.cols});
    caffe2::TensorCPU t(img_dims, data, nullptr);

    vector<caffe2::TIndex> im_info_dims({1, 3});
    caffe2::TensorCPU im_info(im_info_dims);
    im_info.mutable_data<float>()[0] = imgHeight;
    im_info.mutable_data<float>()[1] = imgWidth;
    im_info.mutable_data<float>()[2] = static_cast<float>(scale);

    return pair<caffe2::TensorCPU, caffe2::TensorCPU>(t, im_info);
}
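
The resize and padding arithmetic in the function above can be checked on its own, without OpenCV or Caffe2. The sketch below isolates the two steps; `computeScale` and `padToMultiple` are hypothetical helper names, and the numbers in the note that follows (800, 1333, 32) are example values only, so substitute whatever your config uses.

```cpp
#include <algorithm>
#include <cassert>

// Scale that brings the shorter side to targetSize, reduced if needed so
// that the longer side does not exceed maxSize.
double computeScale(int height, int width, int targetSize, int maxSize) {
    assert(height > 0 && width > 0);
    const int minSide = std::min(height, width);
    const int maxSide = std::max(height, width);
    double scale = static_cast<double>(targetSize) / minSide;
    if (maxSide * scale > maxSize) {
        scale = static_cast<double>(maxSize) / maxSide;
    }
    return scale;
}

// Rounds value up to the next multiple of stride (same integer formula as
// the padHeight / padWidth lines above).
int padToMultiple(int value, int stride) {
    assert(stride > 0);
    return (value + stride - 1) / stride * stride;
}
```

For example, a 480 x 640 image with a target of 800 and a maximum of 1333 gets a scale of 800 / 480, since the longer side stays under the cap. A resized width of 1067 is then padded to 1088 with a stride of 32, while a height of 800 is already a multiple and stays unchanged.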

@HappyKerry

@daquexian Thanks

@HappyKerry

HappyKerry commented May 24, 2018

@daquexian
I am new to Caffe2 and Detectron. I converted the pkl model to a pb model following your steps, and I want to use the following code to run detection on the GPU:
void run() {
NetDef init_net_def, predict_net_def;
DeviceOption option;
option.set_device_type(caffe2::CUDA);
caffe2::CUDAContext context;
caffe2::CUDAContext *pcontext=&context;

pcontext = new caffe2::CUDAContext(option);
init_net_def.mutable_device_option()->CopyFrom(option);
predict_net_def.mutable_device_option()->CopyFrom(option);

unique_ptr<NetBase> predict_net;

cout<<"pos1"<<endl;
CAFFE_ENFORCE(ReadProtoFromFile("Detect_init.pb", &init_net_def));
CAFFE_ENFORCE(ReadProtoFromFile("Detect.pb", &predict_net_def));

init_net_def.mutable_device_option()->set_device_type(CUDA);
predict_net_def.mutable_device_option()->set_device_type(CUDA);

    Workspace workspace2;
    Workspace *workspace=&workspace2;

 cout<<"pos2"<<endl;


workspace->RunNetOnce(init_net_def);

cout<<"pos3"<<endl;

predict_net = CreateNet(predict_net_def,workspace);

cout<<"pos4"<<endl;

    auto inputBlobs = getInputBlobFromFile("1.jpg");


    caffe2::Blob* blob = workspace->GetBlob("data");
    auto *tensor = blob->GetMutable<caffe2::TensorCPU>();
    tensor->ResizeLike(inputBlobs.first);
    tensor->ShareData(inputBlobs.first);

    blob = workspace->GetBlob("im_info");
    tensor = blob->GetMutable<caffe2::TensorCPU>();
    tensor->ResizeLike(inputBlobs.second);
    tensor->ShareData(inputBlobs.second);

    predict_net->Run();
    blob = workspace->GetBlob("score_nms");
    auto score_tensor = blob->Get<caffe2::TensorCPU>();
    blob = workspace->GetBlob("class_nms");
    auto class_tensor = blob->Get<caffe2::TensorCPU>();
    blob = workspace->GetBlob("bbox_nms");
    auto bbox_tensor = blob->Get<caffe2::TensorCPU>();

    for (int i = 0; i < bbox_tensor.dim(0); i++) {
        LOG(INFO) << "class: " << class_tensor.data<float>()[i];
        LOG(INFO) << "score: " << score_tensor.data<float>()[i];
        LOG(INFO) << "bbox: "
                  << bbox_tensor.data<float>()[i * 4] << ", "
                  << bbox_tensor.data<float>()[i * 4 + 1] << ", "
                  << bbox_tensor.data<float>()[i * 4 + 2] << ", "
                  << bbox_tensor.data<float>()[i * 4 + 3];
    }

}

but an error occurs after calling
predict_net = CreateNet(predict_net_def,workspace);

terminate called after throwing an instance of 'caffe2::EnforceNotMet'
what(): [enforce fail at operator.cc:185] op. Cannot create operator of type 'GenerateProposals' on the device 'CUDA'. Verify that implementation for the corresponding device exist. It might also happen if the binary is not linked with the operator implementation code. If Python frontend is used it might happen if dyndep.InitOpsLibrary call is missing. Operator def: input: "rpn_cls_probs_fpn2_cpu" input: "rpn_bbox_pred_fpn2_cpu" input: "im_info" input: "anchor2_cpu" output: "rpn_rois_fpn2" output: "rpn_roi_probs_fpn2" name: "" type: "GenerateProposals" arg { name: "nms_thres" f: 0.7 } arg { name: "min_size" f: 0 } arg { name: "spatial_scale" f: 0.25 } arg { name: "correct_transform_coords" i: 1 } arg { name: "post_nms_topN" i: 1000 } arg { name: "pre_nms_topN" i: 1000 } device_option { device_type: 1 }

Thanks

@SerinaWei

@daquexian Thank you so much for the code snippet! I was able to get it to work with the all-zero fake input! I think the key is the padding. I didn't do padding before.

@HappyKerry

@SerinaWei Are you using the CPU? Could you please provide an example of how to use it? Thanks

@SerinaWei

@HappyKerry Yes, I am using the CPU. Here is my code snippet (almost identical to daquexian's). I used the all-zero input to test the functionality.

===================================================================

caffe2::NetDef _initNet, _predictNet;
caffe2::Workspace ws;

ws.RunNetOnce(_initNet);
caffe2::NetBase* net = ws.CreateNet(_predictNet);

auto inputBlobs = getInputBlobFromFile("");

caffe2::Blob* blob = ws.GetBlob("data");
auto *tensor = blob->GetMutable<caffe2::TensorCPU>();
tensor->ResizeLike(inputBlobs.first);
tensor->ShareData(inputBlobs.first);

blob = ws.GetBlob("im_info");
tensor = blob->GetMutable<caffe2::TensorCPU>();
tensor->ResizeLike(inputBlobs.second);
tensor->ShareData(inputBlobs.second);

net->Run();

@HappyKerry

@SerinaWei @daquexian @orionr
In CPU mode I hit the following problem. Why is it so hard to use Detectron in C++?

terminate called after throwing an instance of 'caffe2::EnforceNotMet'
what(): [enforce fail at operator.cc:185] op. Cannot create operator of type 'BatchPermutation' on the device 'CPU'. Verify that implementation for the corresponding device exist. It might also happen if the binary is not linked with the operator implementation code. If Python frontend is used it might happen if dyndep.InitOpsLibrary call is missing. Operator def: input: "roi_feat_shuffled" input: "rois_idx_restore_int32" output: "roi_feat" name: "" type: "BatchPermutation" device_option { } engine: ""

@daquexian
Contributor

daquexian commented May 25, 2018

@HappyKerry It seems that BatchPermutation only has a CPU implementation. You can register a CUDA operator yourself using GPUFallbackOp, as in https://github.com/pytorch/pytorch/blob/master/caffe2/sgd/lars_op_gpu.cu
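
To picture what a fallback registration does, here is a toy model in plain C++. It is not Caffe2 code: `OpRegistry`, `Device`, `Kernel` and `addCudaFallback` are hypothetical names. The idea is that operators are looked up by type and device, a missing entry produces the "Cannot create operator" error seen in this thread, and a fallback entry simply wraps the CPU kernel with copies to and from the host. In Caffe2 itself the registration is done with `REGISTER_CUDA_OPERATOR` and `GPUFallbackOp` from `caffe2/operators/operator_fallback_gpu.h`; the exact template form differs between versions, so check the header in your checkout.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Toy model only: none of these names are Caffe2 APIs.
enum class Device { CPU, CUDA };
using Kernel = std::function<std::vector<float>(const std::vector<float>&)>;

class OpRegistry {
public:
    void add(const std::string& type, Device device, Kernel kernel) {
        kernels_[std::make_pair(type, device)] = std::move(kernel);
    }

    // Registers the existing CPU kernel under the CUDA key. A real fallback
    // copies inputs device-to-host, runs the CPU kernel, and copies the
    // outputs back; here both copies are ordinary vector copies.
    void addCudaFallback(const std::string& type) {
        Kernel cpu = find(type, Device::CPU);
        add(type, Device::CUDA, [cpu](const std::vector<float>& deviceInput) {
            std::vector<float> hostInput = deviceInput;      // stands in for device-to-host
            std::vector<float> hostOutput = cpu(hostInput);  // CPU implementation runs here
            return hostOutput;                               // stands in for host-to-device
        });
    }

    // Throws when no kernel is registered for this (type, device) pair.
    const Kernel& find(const std::string& type, Device device) const {
        auto it = kernels_.find(std::make_pair(type, device));
        if (it == kernels_.end()) {
            throw std::runtime_error("Cannot create operator of type '" + type + "'");
        }
        return it->second;
    }

private:
    std::map<std::pair<std::string, Device>, Kernel> kernels_;
};
```

The trade-off is the extra host/device copies on every run, which is usually acceptable for small tensors but can matter for large ones.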

@SerinaWei

@daquexian Is the maximum size you can scale to on Android for detection smaller than 1333? Thanks so much for your help!

@daquexian
Contributor

daquexian commented May 26, 2018 via email

@SerinaWei

@daquexian Good to know! I had to use a size smaller than 1333 as well. Otherwise it hangs there.

@ferasboulala

I think I am starting to lose my mind over this. I have tried countless times to run the Caffe2 Detectron model from the model zoo in C++ on the GPU, but I keep getting the same error:

terminate called after throwing an instance of 'caffe2::EnforceNotMet' what(): [enforce fail at operator.cc:187] op. Cannot create operator of type 'GenerateProposals' on the device 'CUDA'. Verify that implementation for the corresponding device exist. It might also happen if the binary is not linked with the operator implementation code. If Python frontend is used it might happen if dyndep.InitOpsLibrary call is missing. Operator def: input: "rpn_cls_probs" input: "rpn_bbox_pred" input: "im_info" input: "anchor" output: "rpn_rois" output: "rpn_roi_probs" name: "" type: "GenerateProposals" arg { name: "nms_thres" f: 0.7 } arg { name: "min_size" f: 0 } arg { name: "spatial_scale" f: 0.0625 } arg { name: "correct_transform_coords" i: 1 } arg { name: "post_nms_topN" i: 1000 } arg { name: "pre_nms_topN" i: 6000 } device_option { device_type: 1 }

If I try it on an FPN model, I get a different operation in the error message (like BBoxTransform). I have linked everything in my CMakeLists (caffe2, caffe2_gpu, detectron_ops_gpu and cuda libs). Here's even a code snippet :

  caffe2::DeviceOption option;
  option.set_device_type(caffe2::CUDA);
  option.set_cuda_gpu_id(0);
  new caffe2::CUDAContext(option);

  caffe2::NetDef init_model, predict_model;
  CAFFE_ENFORCE(ReadProtoFromFile("init_net.pb", &init_model));
  CAFFE_ENFORCE(ReadProtoFromFile("predict_net.pb", &predict_model));
  init_model.mutable_device_option()->set_device_type(caffe2::CUDA); 
  predict_model.mutable_device_option()->set_device_type(caffe2::CUDA); 

  caffe2::Workspace workspace("tmp");
  workspace.RunNetOnce(init_model);
  caffe2::NetBase* net = workspace.CreateNet(predict_model); // line where it fails

Everything works if I run it on the CPU, but inference takes about 20 seconds on a quad-core i7 @ 3.4 GHz, so I would like to run it on an NVIDIA GTX 1080 Ti.

Does anybody have a clue on what is going on here ?

@gadcam
Contributor

gadcam commented Aug 16, 2018

@ferasboulala how did you get these pb files ?
If you exported them for CPU I would not be that confident that they will work on GPU, I would try with pb files for GPU before anything else.

@ferasboulala

I exported them with the --device gpu option.

@ferasboulala

@gadcam GenerateProposals seems like it does not have any CUDA implementation in pure caffe2. Could it be that?

If anyone could share an example working code (and potentially CMakeLists while we are at it) with GPU in C++, that'd be amazing.

@Vince6121

@ferasboulala Same problem here. After a lot of searching, I can't find any C++ CUDA examples. Are we the only ones using C++?

@Yangjayhui

@ferasboulala It works on the CPU, but on the GPU the error occurs after calling CreateNet(predict_net_def);
