[BUG in torch1.11] THCDeviceUtils.cuh is deprecated in PyTorch 1.11, so compiling nms_rotated fails (yolov5_obb/utils/nms_rotated/src/poly_nms_cuda.cu:5:10: fatal error: THC/THCDeviceUtils.cuh: No such file or directory) #151

Open
QingYuan-L opened this issue Jan 11, 2022 · 15 comments

Comments

@QingYuan-L

QingYuan-L commented Jan 11, 2022

See pytorch/pytorch#65472

@QingYuan-L
Author

yolov5_obb/utils/nms_rotated/src/poly_nms_cuda.cu:5:10: fatal error: THC/THCDeviceUtils.cuh: No such file or directory

@hukaixuan19970627
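Owner

Hmm, I couldn't get it to work. I made the changes described in the PR, but then ran into a series of other missing-header errors, so I'm going to leave the code as it is. Switching to a stable torch release compiles successfully.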
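
To see which of the headers mentioned in this thread a given PyTorch install actually ships, one option is to look for them in the extension include directories. The sketch below is a hypothetical helper (`find_headers` and `torch_include_dirs` are illustrative names, not part of yolov5_obb). It assumes that `torch.utils.cpp_extension.include_paths()` returns the directories the extension build searches.

```python
import os

# Headers named in this thread: the old THC ones and the ATen/c10 ones
# used by the later workaround.
HEADERS = [
    "THC/THC.h",
    "THC/THCDeviceUtils.cuh",
    "ATen/ceil_div.h",
    "c10/cuda/CUDACachingAllocator.h",
]


def find_headers(include_dirs, headers=HEADERS):
    """Map each header to the first include dir that contains it, or None."""
    found = {}
    for header in headers:
        found[header] = next(
            (d for d in include_dirs if os.path.isfile(os.path.join(d, header))),
            None,
        )
    return found


def torch_include_dirs():
    """Include dirs of the installed torch, or [] if torch is not importable."""
    try:
        from torch.utils.cpp_extension import include_paths
    except ImportError:
        return []
    return include_paths()


if __name__ == "__main__":
    for header, where in find_headers(torch_include_dirs()).items():
        print(f"{header}: {'found in ' + where if where else 'missing'}")
```

If the two THC headers are reported missing while the ATen/c10 ones are found, the install is one where the original poly_nms_cuda.cu cannot compile unchanged.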

@QingYuan-L
Author

We can wait for PyTorch 1.11 (stable) to test this PR. If you use the newest NGC image, the PR may work. My dev environment: CUDA 11.5, PyTorch 1.11, nvcr.io/nvidia/pytorch:21.12-py3 from NGC.

@Wangfeng2394

What does that say? I can't follow it. We're all Chinese here, so just write in Chinese.

@hukaixuan19970627
Owner

#152

@QingYuan-L
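Author

All new projects at my company now standardize on NVIDIA's latest official image (NGC), nvcr.io/nvidia/pytorch:21.12-py3, which ships PyTorch 1.11.0a0+b6df043. It is not clear yet whether the stable PyTorch 1.11 release will really drop THC, since quite a few projects depend on it.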

@hukaixuan19970627
Owner

They probably won't drop all of it in a single release. There should be a transitional version that emits warnings prompting people to switch APIs.

@Pol1000

Pol1000 commented May 30, 2022

Hi,
I have the same error:

/home/ubuntu/PSD/utils/nms_rotated/src/poly_nms_cuda.cu:5:10: fatal error: THC/THC.h: No such file or directory
 #include <THC/THC.h>
          ^~~~~~~~~~~

Is there any news?
I'm using:
torch=1.11.0+cu115
cuda 11.5

Thanks.

@hukaixuan19970627
Owner

@Pol1000 #152

@LUO77123

torch 1.12 is out, and 1.11 only has the 1.11.0 release. Is the torch version now restricted to below 1.11?

@caolianxue

I followed https://blog.csdn.net/weixin_41868417/article/details/123819183 and that resolved it.

@shotyme

shotyme commented Jan 10, 2023

My solution to compile successfully was to downgrade torch and change the requirements.txt to:

torch==1.9.1
torchvision==0.10.1

@whiterAutumn

A related situation: on Windows the extension compiles successfully, but training does not work, because the loss does not drop. The same data, model, and parameters train normally under Ubuntu.

Epoch gpu_mem box obj cls theta labels img_size
52/2999 5.44G 0.1232 0.3267 0.02696 0.7272 33 1280: 100%|██████████| 14/14 [00:17<00:00, 1.26s/it]
Class Images Labels P R mAP@.5 mAP@.5:.95: 100%|██████████| 7/7 [00:04<00:00, 1.63it/s]
all 53 0 0 0 0 0

Epoch gpu_mem box obj cls theta labels img_size
53/2999 5.44G 0.1231 0.3131 0.02698 0.7273 33 1280: 100%|██████████| 14/14 [00:17<00:00, 1.27s/it]
Class Images Labels P R mAP@.5 mAP@.5:.95: 100%|██████████| 7/7 [00:04<00:00, 1.59it/s]
all 53 0 0 0 0 0

Epoch gpu_mem box obj cls theta labels img_size
54/2999 5.44G 0.1232 0.3001 0.02696 0.7272 21 1280: 100%|██████████| 14/14 [00:17<00:00, 1.28s/it]
Class Images Labels P R mAP@.5 mAP@.5:.95: 100%|██████████| 7/7 [00:04<00:00, 1.63it/s]
all 53 0 0 0 0 0

@regainOWO

regainOWO commented Apr 20, 2023

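@hukaixuan19970627 I think this can be solved by adding one more, modified source file and compiling a different file depending on the PyTorch version. I made some changes to setup.py: the function make_cuda_ext takes an extra parameter, sources_cuda_later, and compiles different files depending on the PyTorch version.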

#!/usr/bin/env python
import os
from setuptools import setup

import torch
from torch.utils.cpp_extension import (BuildExtension, CppExtension,
                                       CUDAExtension)


def make_cuda_ext(name, module, sources, sources_cuda=[], sources_cuda_later=[]):
    define_macros = []
    extra_compile_args = {'cxx': []}

    if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
        define_macros += [('WITH_CUDA', None)]
        extension = CUDAExtension
        extra_compile_args['nvcc'] = [
            '-D__CUDA_NO_HALF_OPERATORS__',
            '-D__CUDA_NO_HALF_CONVERSIONS__',
            '-D__CUDA_NO_HALF2_OPERATORS__',
        ]
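        # Compare (major, minor) as integers: a plain string comparison
        # would order '1.9.1' after '1.11'.
        torch_version = tuple(int(x) for x in torch.__version__.split('.')[:2])
        if torch_version < (1, 11) or len(sources_cuda_later) == 0:
            sources = sources + sources_cuda
        else:
            sources = sources + sources_cuda_later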
    else:
        print(f'Compiling {name} without CUDA')
        extension = CppExtension
        # raise EnvironmentError('CUDA is required to compile MMDetection!')

    return extension(
        name=f'{module}.{name}',
        sources=[os.path.join(*module.split('.'), p) for p in sources],
        define_macros=define_macros,
        extra_compile_args=extra_compile_args)


# python setup.py develop
if __name__ == '__main__':
    # write_version_py()
    setup(
        name='nms_rotated',
        ext_modules=[
            make_cuda_ext(
                name='nms_rotated_ext',
                module='',
                sources=[
                    'src/nms_rotated_cpu.cpp',
                    'src/nms_rotated_ext.cpp'
                ],
                sources_cuda=[
                    'src/nms_rotated_cuda.cu',
                    'src/poly_nms_cuda.cu',
                ],
                sources_cuda_later=[
                    'src/nms_rotated_cuda.cu',
                    'src/poly_nms_cuda_1.11.cu',
                ]),
        ],
        cmdclass={'build_ext': BuildExtension},
        zip_safe=False)
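
A version gate like the one in this setup.py is safer when it compares numbers rather than raw strings, because Python orders strings character by character and `'1.9.1' < '1.11'` evaluates to `False`. The sketch below shows one way to do the comparison. The helper names `torch_major_minor` and `needs_thc_free_sources` are hypothetical, and the version strings are the ones reported in this thread.

```python
import re


def torch_major_minor(version):
    """Return (major, minor) as ints from a torch version string."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def needs_thc_free_sources(version):
    """True when the THC-free .cu file should be compiled (torch >= 1.11)."""
    return torch_major_minor(version) >= (1, 11)


# Columns: version, naive string comparison, numeric gate.
for v in ["1.9.1", "1.10.2", "1.11.0a0+b6df043", "1.11.0+cu115", "2.1.0+cu121"]:
    print(v, v < "1.11", needs_thc_free_sources(v))
```

For `1.9.1` the string comparison reports that the version is not below 1.11, which would select the THC-free file on a release that predates `ATen/ceil_div.h`. The numeric gate keeps the original file for that release.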

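The contents of the new file poly_nms_cuda_1.11.cu are below. The modified parts are marked by the commented-out original lines, and the changes follow the reference linked in the original comment.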

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

// Not available in PyTorch 1.11:
// #include <THC/THC.h>
// #include <THC/THCDeviceUtils.cuh>
#include <ATen/cuda/ThrustAllocator.h>
#include <ATen/ceil_div.h>  // at::ceil_div replaces THCCeilDiv

#include <cstring>  // memset
#include <vector>
#include <iostream>

#define CUDA_CHECK(condition) \
  /* Code block avoids redefinition of cudaError_t error */ \
  do { \
    cudaError_t error = condition; \
    if (error != cudaSuccess) { \
      std::cout << cudaGetErrorString(error) << std::endl; \
    } \
  } while (0)

#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
int const threadsPerBlock = sizeof(unsigned long long) * 8;


#define maxn 10
// const double eps=1E-8;

__device__ inline int sig(float d){
    // return(d>eps)-(d<-eps);
    return (d > 0.00000001) - (d < -0.00000001);
}

__device__ inline int point_eq(const float2 a, const float2 b) {
    return sig(a.x - b.x) == 0 && sig(a.y - b.y)==0;
}

__device__ inline void point_swap(float2 *a, float2 *b) {
    float2 temp = *a;
    *a = *b;
    *b = temp;
}

__device__ inline void point_reverse(float2 *first, float2* last)
{
    while ((first!=last)&&(first!=--last)) {
        point_swap (first,last);
        ++first;
    }
}

__device__ inline float cross(float2 o,float2 a,float2 b){  // cross product
    return(a.x-o.x)*(b.y-o.y)-(b.x-o.x)*(a.y-o.y);
}
__device__ inline float area(float2* ps,int n){
    ps[n]=ps[0];
    float res=0;
    for(int i=0;i<n;i++){
        res+=ps[i].x*ps[i+1].y-ps[i].y*ps[i+1].x;
    }
    return res/2.0;
}
__device__ inline int lineCross(float2 a,float2 b,float2 c,float2 d,float2&p){
    float s1,s2;
    s1=cross(a,b,c);
    s2=cross(a,b,d);
    if(sig(s1)==0&&sig(s2)==0) return 2;
    if(sig(s2-s1)==0) return 0;
    p.x=(c.x*s2-d.x*s1)/(s2-s1);
    p.y=(c.y*s2-d.y*s1)/(s2-s1);
    return 1;
}

__device__ inline void polygon_cut(float2*p,int&n,float2 a,float2 b, float2* pp){

    int m=0;p[n]=p[0];
    for(int i=0;i<n;i++){
        if(sig(cross(a,b,p[i]))>0) pp[m++]=p[i];
        if(sig(cross(a,b,p[i]))!=sig(cross(a,b,p[i+1])))
            lineCross(a,b,p[i],p[i+1],pp[m++]);
    }
    n=0;
    for(int i=0;i<m;i++)
        if(!i||!(point_eq(pp[i], pp[i-1])))
            p[n++]=pp[i];
    // while(n>1&&p[n-1]==p[0])n--;
    while(n>1&&point_eq(p[n-1], p[0]))n--;
}

//-----------------------------------------------//
// Returns the signed intersection area of triangles oab and ocd, where o is the origin
__device__ inline float intersectArea(float2 a,float2 b,float2 c,float2 d){
    float2 o = make_float2(0,0);
    int s1=sig(cross(o,a,b));
    int s2=sig(cross(o,c,d));
    if(s1==0||s2==0)return 0.0; // degenerate case, area is 0
    // if(s1==-1) swap(a,b);
    // if(s2==-1) swap(c,d);
    if (s1 == -1) point_swap(&a, &b);
    if (s2 == -1) point_swap(&c, &d);
    float2 p[10]={o,a,b};
    int n=3;
    float2 pp[maxn];
    polygon_cut(p,n,o,c,pp);
    polygon_cut(p,n,c,d,pp);
    polygon_cut(p,n,d,o,pp);
    float res=fabs(area(p,n));
    if(s1*s2==-1) res=-res;
    return res;
}
// Compute the intersection area of two polygons
__device__ inline float intersectArea(float2*ps1,int n1,float2*ps2,int n2){
    if(area(ps1,n1)<0) point_reverse(ps1,ps1+n1);
    if(area(ps2,n2)<0) point_reverse(ps2,ps2+n2);
    ps1[n1]=ps1[0];
    ps2[n2]=ps2[0];
    float res=0;
    for(int i=0;i<n1;i++){
        for(int j=0;j<n2;j++){
            res+=intersectArea(ps1[i],ps1[i+1],ps2[j],ps2[j+1]);
        }
    }
    return res; // assume res is positive!
}

// TODO: optimize by first computing the IoU of the two horizontal bounding boxes (HBBs)
__device__ inline float devPolyIoU(float const * const p, float const * const q) {
    float2 ps1[maxn], ps2[maxn];
    int n1 = 4;
    int n2 = 4;
    for (int i = 0; i < 4; i++) {
        ps1[i].x = p[i * 2];
        ps1[i].y = p[i * 2 + 1];

        ps2[i].x = q[i * 2];
        ps2[i].y = q[i * 2 + 1];
    }
    float inter_area = intersectArea(ps1, n1, ps2, n2);
    float union_area = fabs(area(ps1, n1)) + fabs(area(ps2, n2)) - inter_area;
    float iou = 0;
    if (union_area == 0) {
        iou = (inter_area + 1) / (union_area + 1);
    } else {
        iou = inter_area / union_area;
    }
    return iou;
}

__global__ void poly_nms_kernel(const int n_polys, const float nms_overlap_thresh,
                            const float *dev_polys, unsigned long long *dev_mask) {
    const int row_start = blockIdx.y;
    const int col_start = blockIdx.x;

    const int row_size =
            min(n_polys - row_start * threadsPerBlock, threadsPerBlock);
    const int cols_size =
            min(n_polys - col_start * threadsPerBlock, threadsPerBlock);

    __shared__ float block_polys[threadsPerBlock * 9];
    if (threadIdx.x < cols_size) {
        block_polys[threadIdx.x * 9 + 0] =
            dev_polys[(threadsPerBlock * col_start + threadIdx.x) * 9 + 0];
        block_polys[threadIdx.x * 9 + 1] =
            dev_polys[(threadsPerBlock * col_start + threadIdx.x) * 9 + 1];
        block_polys[threadIdx.x * 9 + 2] =
            dev_polys[(threadsPerBlock * col_start + threadIdx.x) * 9 + 2];
        block_polys[threadIdx.x * 9 + 3] =
            dev_polys[(threadsPerBlock * col_start + threadIdx.x) * 9 + 3];
        block_polys[threadIdx.x * 9 + 4] =
            dev_polys[(threadsPerBlock * col_start + threadIdx.x) * 9 + 4];
        block_polys[threadIdx.x * 9 + 5] =
            dev_polys[(threadsPerBlock * col_start + threadIdx.x) * 9 + 5];
        block_polys[threadIdx.x * 9 + 6] =
            dev_polys[(threadsPerBlock * col_start + threadIdx.x) * 9 + 6];
        block_polys[threadIdx.x * 9 + 7] =
            dev_polys[(threadsPerBlock * col_start + threadIdx.x) * 9 + 7];
        block_polys[threadIdx.x * 9 + 8] =
            dev_polys[(threadsPerBlock * col_start + threadIdx.x) * 9 + 8];
    }
    __syncthreads();

    if (threadIdx.x < row_size) {
        const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
        const float *cur_box = dev_polys + cur_box_idx * 9;
        int i = 0;
        unsigned long long t = 0;
        int start = 0;
        if (row_start == col_start) {
            start = threadIdx.x + 1;
        }
        for (i = start; i < cols_size; i++) {
            if (devPolyIoU(cur_box, block_polys + i * 9) > nms_overlap_thresh) {
                t |= 1ULL << i;
            }
        }
        // const int col_blocks = THCCeilDiv(n_polys, threadsPerBlock);
        const int col_blocks =  at::ceil_div(n_polys, threadsPerBlock);
        dev_mask[cur_box_idx * col_blocks + col_start] = t;
    }
}

// boxes is a N x 9 tensor
at::Tensor poly_nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {

    at::DeviceGuard guard(boxes.device());

    using scalar_t = float;
    AT_ASSERTM(boxes.device().is_cuda(), "boxes must be a CUDA tensor");
    auto scores = boxes.select(1, 8);
    auto order_t = std::get<1>(scores.sort(0, /*descending=*/true));
    auto boxes_sorted = boxes.index_select(0, order_t);

    int boxes_num = boxes.size(0);

    // const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
    const int col_blocks =  at::ceil_div(boxes_num, threadsPerBlock);

    scalar_t* boxes_dev = boxes_sorted.data_ptr<scalar_t>();

    // THCState *state = at::globalContext().lazyInitCUDA();

    unsigned long long* mask_dev = NULL;

    // mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
    mask_dev = (unsigned long long*) c10::cuda::CUDACachingAllocator::raw_alloc(boxes_num * col_blocks * sizeof(unsigned long long));

    // dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
    //             THCCeilDiv(boxes_num, threadsPerBlock));
    dim3 blocks(at::ceil_div(boxes_num, threadsPerBlock),
                at::ceil_div(boxes_num, threadsPerBlock));
    dim3 threads(threadsPerBlock);
    poly_nms_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(boxes_num,
                                        nms_overlap_thresh,
                                        boxes_dev,
                                        mask_dev);

    std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
    // THCudaCheck(cudaMemcpyAsync(
    //     &mask_host[0],
    //     mask_dev,
    //     sizeof(unsigned long long) * boxes_num * col_blocks,
    //     cudaMemcpyDeviceToHost,
    //     at::cuda::getCurrentCUDAStream()
    // ));
    C10_CUDA_CHECK(cudaMemcpyAsync(
        &mask_host[0],
        mask_dev,
        sizeof(unsigned long long) * boxes_num * col_blocks,
        cudaMemcpyDeviceToHost,
        at::cuda::getCurrentCUDAStream()
    ));

    std::vector<unsigned long long> remv(col_blocks);
    memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);

    at::Tensor keep = at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU));
    int64_t* keep_out = keep.data_ptr<int64_t>();

    int num_to_keep = 0;
    for (int i = 0; i < boxes_num; i++) {
        int nblock = i / threadsPerBlock;
        int inblock = i % threadsPerBlock;

        if (!(remv[nblock] & (1ULL << inblock))) {
            keep_out[num_to_keep++] = i;
            unsigned long long *p = &mask_host[0] + i * col_blocks;
            for (int j = nblock; j < col_blocks; j++) {
                remv[j] |= p[j];
            }
        }
    }

    // THCudaFree(state, mask_dev);
    c10::cuda::CUDACachingAllocator::raw_delete(mask_dev);

    return order_t.index({
        keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
          order_t.device(), keep.scalar_type())});
}
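
The host-side loop over `remv` and `mask_host` in `poly_nms_cuda` is a greedy reduction over a bitmask of pairwise overlaps. The pure-Python sketch below illustrates that logic under simplifying assumptions: `build_mask` fills only the bits for later boxes (the ones the host loop reads), the `overlaps` callback stands in for `devPolyIoU(...) > nms_overlap_thresh`, and the 1D intervals are made-up demo data. It is a reading aid, not a replacement for the kernel.

```python
WORD_BITS = 64  # mirrors threadsPerBlock = sizeof(unsigned long long) * 8


def build_mask(n, overlaps):
    """Row i holds ceil(n / 64) words; bit (k % 64) of word (k // 64) is set
    when box k comes after box i in score order and overlaps(i, k) is true."""
    col_blocks = (n + WORD_BITS - 1) // WORD_BITS
    mask = [[0] * col_blocks for _ in range(n)]
    for i in range(n):
        for k in range(i + 1, n):
            if overlaps(i, k):
                mask[i][k // WORD_BITS] |= 1 << (k % WORD_BITS)
    return mask


def reduce_mask(n, mask):
    """Greedy pass in score order: keep a box unless an earlier kept box
    already marked it as removed."""
    col_blocks = (n + WORD_BITS - 1) // WORD_BITS
    removed = [0] * col_blocks
    keep = []
    for i in range(n):
        nblock, inblock = divmod(i, WORD_BITS)
        if not (removed[nblock] >> inblock) & 1:
            keep.append(i)
            for j in range(nblock, col_blocks):
                removed[j] |= mask[i][j]
    return keep


def iou_1d(a, b):
    """IoU of two closed intervals (start, end); a stand-in for polygon IoU."""
    inter = max(0, min(a[1], b[1]) - max(a[0], b[0]))
    union = (a[1] - a[0]) + (b[1] - b[0]) - inter
    return inter / union


# Intervals already sorted by descending score (hypothetical data).
boxes = [(0, 10), (1, 11), (20, 30), (21, 29), (5, 6)]
mask = build_mask(len(boxes), lambda i, k: iou_1d(boxes[i], boxes[k]) > 0.5)
print(reduce_mask(len(boxes), mask))  # [0, 2, 4]
```

The indices returned are positions in the score-sorted order, which is why the CUDA function maps them back through `order_t` before returning.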

@Stark-Bear

Thanks @regainOWO, this is a really useful method. I hit this error while working in a Docker container. My image environment is: Python 3.9, torch 2.1.0+cu121, torchaudio 2.1.0+cu121, torchvision 0.16.0+cu121. The CUDA toolkit installed on my Linux server is CUDA 12.2.
