diff --git a/gammagl/mpops/torch_ext/cpu/segment_max_cpu.cpp b/gammagl/mpops/torch_ext/cpu/segment_max_cpu.cpp
index e4792279..db265bde 100644
--- a/gammagl/mpops/torch_ext/cpu/segment_max_cpu.cpp
+++ b/gammagl/mpops/torch_ext/cpu/segment_max_cpu.cpp
@@ -23,45 +23,46 @@ std::tuple<torch::Tensor, torch::Tensor> segment_max_cpu_forward(
   auto sizes = x.sizes().vec();
   sizes[0] = N;
-  torch::Tensor out = torch::empty(sizes, x.options());
+  torch::Tensor out = torch::zeros(sizes, x.options());
   torch::Tensor arg_out = torch::full_like(out, out.size(0), index.options());
   if (x.numel() == 0) {
-    out.fill_(0.);
     return std::make_tuple(out, arg_out);
   }
-  out.fill_(std::numeric_limits<float>::lowest());
   auto E = x.size(0);
-  auto K = x.numel() / x.size(0);
+  auto K = x.numel() / E;
   auto index_data = index.data_ptr<int64_t>();
   auto arg_out_data = arg_out.data_ptr<int64_t>();
-  using scalar_t = float;
-  auto x_data = x.data_ptr<scalar_t>();
-  auto out_data = out.data_ptr<scalar_t>();
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "segment_max_cpu_forward", [&]() {
+    out.fill_(std::numeric_limits<scalar_t>::lowest());
+    auto x_data = x.data_ptr<scalar_t>();
+    auto out_data = out.data_ptr<scalar_t>();
-  int64_t idx;
-#ifdef COMPILE_WITH_OMP
-#pragma omp parallel for private(idx)
-#endif
-  for (auto e = 0; e < E; ++e) {
-    idx = index_data[e];
-    TORCH_CHECK_INDEX(idx < N, "Index out of bounds: ", idx, " >= ", N);
-    for (auto k = 0; k < K; ++k) {
-      scalar_t current_val = x_data[e * K + k];
-      scalar_t& max_val = out_data[idx * K + k];
-      int64_t& max_idx = arg_out_data[idx * K + k];
-#ifdef COMPILE_WITH_OMP
-#pragma omp critical
-#endif
-      {
-        if (max_val < current_val) {
-          max_val = current_val;
-          max_idx = e;
+    int64_t idx;
+    #ifdef COMPILE_WITH_OMP
+    #pragma omp parallel for private(idx)
+    #endif
+    for (auto e = 0; e < E; ++e) {
+      idx = index_data[e];
+      TORCH_CHECK_INDEX(idx < N, "Index out of bounds: ", idx, " >= ", N);
+      for (auto k = 0; k < K; ++k) {
+        scalar_t current_val = x_data[e * K + k];
+        scalar_t& max_val = out_data[idx * K + k];
+        int64_t& max_idx = arg_out_data[idx * K + k];
+        #ifdef COMPILE_WITH_OMP
+        #pragma omp critical
+        #endif
+        {
+          if (max_val < current_val) {
+            max_val = current_val;
+            max_idx = e;
+          }
+        }
       }
     }
-  }
-  }
+
+  });
 
   return std::make_tuple(out, arg_out);
 }
diff --git a/gammagl/mpops/torch_ext/cpu/segment_mean_cpu.cpp b/gammagl/mpops/torch_ext/cpu/segment_mean_cpu.cpp
index 7d8610a7..98f2d1da 100644
--- a/gammagl/mpops/torch_ext/cpu/segment_mean_cpu.cpp
+++ b/gammagl/mpops/torch_ext/cpu/segment_mean_cpu.cpp
@@ -4,6 +4,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -30,49 +31,50 @@ torch::Tensor segment_mean_cpu_forward(
   }
 
   auto E = x.size(0);
-  auto K = x.numel() / x.size(0);
+  auto K = x.numel() / E;
   auto index_data = index.data_ptr<int64_t>();
   auto arg_out_data = arg_out.data_ptr<int64_t>();
-  // AT_DISPATCH_ALL_TYPES(x.scalar_type(), "__ops_name", [&] {
-  using scalar_t = float;
-  auto x_data = x.data_ptr<scalar_t>();
-  auto out_data = out.data_ptr<scalar_t>();
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "segment_mean_cpu_forward", [&]() {
-  torch::Tensor degree = torch::zeros({1, index.size(0)}, x.options());
-  auto degree_data = degree.data_ptr<scalar_t>();
+    auto x_data = x.data_ptr<scalar_t>();
+    auto out_data = out.data_ptr<scalar_t>();
-#ifdef COMPILE_WITH_OMP
-#pragma omp parallel for
-#endif
-  for (auto e = 0; e < E; ++e) {
-    auto idx = index_data[e];
-    degree_data[idx] += 1;
-    for (auto k = 0; k < K; ++k) {
-#ifdef COMPILE_WITH_OMP
-#pragma omp critical
-#endif
-      out_data[idx * K + k] += x_data[e * K + k];
-      arg_out_data[idx * K + k] = e;
-    }
-  }
-  // });
-  out = out.contiguous();
-  degree = degree.contiguous();
+    torch::Tensor degree = torch::zeros({1, index.size(0)}, x.options());
+    auto degree_data = degree.data_ptr<scalar_t>();
-#ifdef COMPILE_WITH_OMP
-#pragma omp parallel for
-#endif
-  for (auto e = 0; e < E; ++e) {
-    if (degree_data[e] > 1) {
-      for (auto k = 0; k < K; ++k) {
-#ifdef COMPILE_WITH_OMP
-#pragma omp critical
-#endif
-        out_data[e * K + k] /= degree_data[e];
+    #ifdef COMPILE_WITH_OMP
+    #pragma omp parallel for
+    #endif
+    for (auto e = 0; e < E; ++e) {
+      auto idx = index_data[e];
+      degree_data[idx] += 1;
+      for (auto k = 0; k < K; ++k) {
+        #ifdef COMPILE_WITH_OMP
+        #pragma omp critical
+        #endif
+        out_data[idx * K + k] += x_data[e * K + k];
+        arg_out_data[idx * K + k] = e;
+      }
+    }
-    }
-  }
+    out = out.contiguous();
+    degree = degree.contiguous();
+
+    #ifdef COMPILE_WITH_OMP
+    #pragma omp parallel for
+    #endif
+    for (auto e = 0; e < E; ++e) {
+      if (degree_data[e] > 1) {
+        for (auto k = 0; k < K; ++k) {
+          #ifdef COMPILE_WITH_OMP
+          #pragma omp critical
+          #endif
+          out_data[e * K + k] /= degree_data[e];
+        }
+      }
+    }
+
+  });
   return out;
 }
diff --git a/gammagl/mpops/torch_ext/cpu/segment_sum_cpu.cpp b/gammagl/mpops/torch_ext/cpu/segment_sum_cpu.cpp
index 0cf5d2c3..5d38478f 100644
--- a/gammagl/mpops/torch_ext/cpu/segment_sum_cpu.cpp
+++ b/gammagl/mpops/torch_ext/cpu/segment_sum_cpu.cpp
@@ -4,6 +4,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -13,7 +14,6 @@ torch::Tensor segment_sum_cpu_forward(
   TORCH_CHECK(index.device().is_cpu(), "index must be CPU tensor");
   TORCH_CHECK_INDEX(
       index.dim() == 1, "index dimension should be 1, but got ", index.dim());
-
   TORCH_CHECK_INDEX(
       x.size(0) == index.size(0),
       "fisrt dimension of x and index should be same");
 
@@ -21,42 +21,38 @@ torch::Tensor segment_sum_cpu_forward(
   x = x.contiguous();  // Make sure x is contiguous.
   index = index.contiguous();
-  // Set up the sizes for the output tensor.
   auto sizes = x.sizes().vec();
   sizes[0] = N;
-
-  // Initialize the output tensor with zeros.
   torch::Tensor out = torch::zeros(sizes, x.options());
-  // If there is no element in x, return the output tensors as they are.
   if (x.numel() == 0) {
     return out;
   }
-  // Get data pointers for index, arg_out, and x.
-  auto index_data = index.data_ptr<int64_t>();
-  auto x_data = x.data_ptr<float>();  // Assuming x is of type float.
-  auto out_data = out.data_ptr<float>();
-
-  // Set up dimensions for iteration.
-  auto E = index.size(0);  // Number of elements to process.
-  // auto K = (x.dim() > 1) ? x.size(1) : 1;  // Size of the inner dimension.
-  auto K = x.numel() / x.size(0);  // Size of the inner dimension.
-
-#ifdef COMPILE_WITH_OMP
-#pragma omp parallel for
-#endif
-  // Iterate over each element in x.
-  for (auto e = 0; e < E; ++e) {
-    auto idx = index_data[e];
-    // Handle accumulation for different dimensions.
-    for (auto k = 0; k < K; ++k) {
-#ifdef COMPILE_WITH_OMP
-#pragma omp critical
-#endif
-      out_data[idx * K + k] += x_data[e * K + k];
-    }
-  }
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "segment_sum_cpu_forward", [&]() {
+    // Get data pointers for index, arg_out, and x.
+    auto index_data = index.data_ptr<int64_t>();
+    auto x_data = x.data_ptr<scalar_t>();
+    auto out_data = out.data_ptr<scalar_t>();
+
+    auto E = index.size(0);  // Number of elements to process.
+    auto K = x.numel() / x.size(0);  // Size of the inner dimension.
+
+    #ifdef COMPILE_WITH_OMP
+    #pragma omp parallel for
+    #endif
+    // Iterate over each element in x.
+    for (auto e = 0; e < E; ++e) {
+      auto idx = index_data[e];
+      // Handle accumulation for different dimensions.
+      for (auto k = 0; k < K; ++k) {
+        #ifdef COMPILE_WITH_OMP
+        #pragma omp critical
+        #endif
+        out_data[idx * K + k] += x_data[e * K + k];
+      }
+    }
+  });
 
   return out;
 }
diff --git a/gammagl/mpops/torch_ext/cuda/segment_max_cuda.cu b/gammagl/mpops/torch_ext/cuda/segment_max_cuda.cu
index a8f52686..e465e944 100644
--- a/gammagl/mpops/torch_ext/cuda/segment_max_cuda.cu
+++ b/gammagl/mpops/torch_ext/cuda/segment_max_cuda.cu
@@ -17,16 +17,47 @@ using torch::autograd::variable_list;
 
 #define THREADS 1024
 #define BLOCKS(N) (N + THREADS - 1) / THREADS
 
-inline __device__ void atomic_max_float(float *addr, float value) {
-  int *addr_as_i = (int *)addr;
-  int old = *addr_as_i;
-  int assumed;
-  do {
-    assumed = old;
-    old = atomicCAS(
-        addr_as_i, assumed,
-        __float_as_int(max(value, __int_as_float(assumed))));
-  } while (assumed != old);
+// template <typename scalar_t>
+// __device__ void atomic_max_float(scalar_t *addr, scalar_t value) {
+//   int *addr_as_i = (int *)addr;
+//   int old = *addr_as_i;
+//   int assumed;
+//   do {
+//     assumed = old;
+//     old = atomicCAS(
+//         addr_as_i, assumed,
+//         __float_as_int(max(value, __int_as_float(assumed))));
+//   } while (assumed != old);
+// }
+
+template <typename scalar_t>
+__device__ void atomic_max(scalar_t* const address, const scalar_t value);
+
+template <>
+__device__ void atomic_max(int32_t* const address, const int32_t value) {
+  atomicMax(address, value);
+}
+
+template <>
+__device__ void atomic_max(float* const address, const float value) {
+  int* const address_as_i = (int*)address;
+  int old = *address_as_i, assumed;
+  do {
+    assumed = old;
+    old = atomicCAS(address_as_i, assumed,
+                    __float_as_int(fmaxf(value, __int_as_float(assumed))));
+  } while (assumed != old);
+}
+
+template <>
+__device__ void atomic_max(double* const address, const double value) {
+  unsigned long long int* const address_as_ull = (unsigned long long int*)address;
+  unsigned long long int old = *address_as_ull, assumed;
+  do {
+    assumed = old;
+    old = atomicCAS(address_as_ull, assumed,
+                    __double_as_longlong(fmax(value, __longlong_as_double(assumed))));
+  } while (assumed != old);
 }
 
 template <typename scalar_t>
@@ -39,7 +70,8 @@ __global__ void segment_max_cuda_forward_kernel(
   if (thread_idx < numel) {
     // TODO: support more data type
     int64_t idx = index_data[e];
-    atomic_max_float(out_data + idx * K + k, x_data[thread_idx]);
+    // atomic_max_float(out_data + idx * K + k, x_data[thread_idx]);
+    atomic_max(out_data + idx * K + k, x_data[thread_idx]);
   }
 }
 
@@ -89,8 +121,8 @@ std::tuple<torch::Tensor, torch::Tensor> segment_max_cuda_forward(
       x.size(0) == index.size(0),
       "fisrt dimension of x and index should be same");
   // only support float Tensor
-  TORCH_CHECK_TYPE(
-      x.scalar_type() == c10::ScalarType::Float, "x should be float Tensor")
+  // TORCH_CHECK_TYPE(
+  //     x.scalar_type() == c10::ScalarType::Float, "x should be float Tensor")
   cudaSetDevice(x.get_device());
   x = x.contiguous();
   index = index.contiguous();
@@ -108,29 +140,103 @@ std::tuple<torch::Tensor, torch::Tensor> segment_max_cuda_forward(
     return std::make_tuple(out, arg_out);
   }
 
-  out.fill_(std::numeric_limits<float>::lowest());
+  // out.fill_(std::numeric_limits<float>::lowest());
   auto E = x.size(0);
   auto K = x.numel() / x.size(0);
 
   auto stream = at::cuda::getCurrentCUDAStream();
 
   // AT_DISPATCH_ALL_TYPES(x.scalar_type(), "__ops_name", [&] {
-  using scalar_t = float;  // temporary usage, delete later
-  auto x_data = x.data_ptr<scalar_t>();
-  auto out_data = out.data_ptr<scalar_t>();
-  auto index_data = index.data_ptr<int64_t>();
-
-  segment_max_cuda_forward_kernel<scalar_t>
-      <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
-          x_data, index_data, out_data, E, K, N, x.numel());
-
-  // out.masked_fill_(out == std::numeric_limits<scalar_t>::lowest(),
-  // (scalar_t)0);
-
-  arg_segment_max_cuda_forward_kernel<scalar_t>
-      <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
-          x_data, index_data, out_data, arg_out_data, E, K, N, x.numel(),
-          out.size(0));
+  // using scalar_t = float;  // temporary usage, delete later
+  // auto x_data = x.data_ptr<scalar_t>();
+  // auto out_data = out.data_ptr<scalar_t>();
+  // auto index_data = index.data_ptr<int64_t>();
+
+  // segment_max_cuda_forward_kernel<scalar_t>
+  //     <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+  //         x_data, index_data, out_data, E, K, N, x.numel());
+
+  // // out.masked_fill_(out == std::numeric_limits<scalar_t>::lowest(),
+  // // (scalar_t)0);
+
+  // arg_segment_max_cuda_forward_kernel<scalar_t>
+  //     <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+  //         x_data, index_data, out_data, arg_out_data, E, K, N, x.numel(),
+  //         out.size(0));
   // });
+
+  if (x.dtype() == torch::kInt8 || x.dtype() == torch::kInt16 || x.dtype() == torch::kInt32 || x.dtype() == torch::kInt64) {
+    if (x.dtype() == torch::kInt8){
+      out.fill_(std::numeric_limits<int8_t>::lowest());
+    } else if (x.dtype() == torch::kInt16){
+      out.fill_(std::numeric_limits<int16_t>::lowest());
+    } else if (x.dtype() == torch::kInt32){
+      out.fill_(std::numeric_limits<int32_t>::lowest());
+    } else if (x.dtype() == torch::kInt64){
+      out.fill_(std::numeric_limits<int64_t>::lowest());
+    }
+    auto type = x.dtype();
+    using scalar_t = int;
+    if (x.dtype() == torch::kInt8 || x.dtype() == torch::kInt16 || x.dtype() == torch::kInt64) {
+      x = x.to(torch::kInt32);
+      out = out.to(torch::kInt32);
+    }
+    // out.fill_(std::numeric_limits<scalar_t>::lowest());
+    auto x_data = x.data_ptr<scalar_t>();
+    auto out_data = out.data_ptr<scalar_t>();
+    auto index_data = index.data_ptr<int64_t>();
+
+    segment_max_cuda_forward_kernel<scalar_t>
+        <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+            x_data, index_data, out_data, E, K, N, x.numel());
+
+    arg_segment_max_cuda_forward_kernel<scalar_t>
+        <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+            x_data, index_data, out_data, arg_out_data, E, K, N, x.numel(),
+            out.size(0));
+
+    out = out.to(type);
+
+  } else if (x.dtype() == torch::kFloat16 || x.dtype() == torch::kFloat32) {
+    auto type = x.dtype();
+    using scalar_t = float;
+    if (x.dtype() == torch::kFloat16) {
+      x = x.to(torch::kFloat32);
+      out = out.to(torch::kFloat32);
+      out.fill_(-65503.9);
+    } else if (x.dtype() == torch::kFloat32) {
+      out.fill_(std::numeric_limits<float>::lowest());
+    }
+    auto x_data = x.data_ptr<scalar_t>();
+    auto out_data = out.data_ptr<scalar_t>();
+    auto index_data = index.data_ptr<int64_t>();
+
+    segment_max_cuda_forward_kernel<scalar_t>
+        <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+            x_data, index_data, out_data, E, K, N, x.numel());
+
+    arg_segment_max_cuda_forward_kernel<scalar_t>
+        <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+            x_data, index_data, out_data, arg_out_data, E, K, N, x.numel(),
+            out.size(0));
+
+    out = out.to(type);
+  } else if (x.dtype() == torch::kFloat64) {
+    using scalar_t = double;
+    out.fill_(std::numeric_limits<double>::lowest());
+    auto x_data = x.data_ptr<scalar_t>();
+    auto out_data = out.data_ptr<scalar_t>();
+    auto index_data = index.data_ptr<int64_t>();
+
+    segment_max_cuda_forward_kernel<scalar_t>
+        <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+            x_data, index_data, out_data, E, K, N, x.numel());
+
+    arg_segment_max_cuda_forward_kernel<scalar_t>
+        <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+            x_data, index_data, out_data, arg_out_data, E, K, N, x.numel(),
+            out.size(0));
+  }
+
   return std::make_tuple(out, arg_out);
 }
diff --git a/gammagl/mpops/torch_ext/cuda/segment_mean_cuda.cu b/gammagl/mpops/torch_ext/cuda/segment_mean_cuda.cu
index 1ec2ac9a..a462c3ba 100644
--- a/gammagl/mpops/torch_ext/cuda/segment_mean_cuda.cu
+++ b/gammagl/mpops/torch_ext/cuda/segment_mean_cuda.cu
@@ -75,8 +75,8 @@ torch::Tensor segment_mean_cuda_forward(
       x.size(0) == index.size(0),
       "fisrt dimension of x and index should be same");
   // only support float Tensor
-  TORCH_CHECK_TYPE(
-      x.scalar_type() == c10::ScalarType::Float, "x should be float Tensor")
+  // TORCH_CHECK_TYPE(
+  //     x.scalar_type() == c10::ScalarType::Float, "x should be float Tensor")
   cudaSetDevice(x.get_device());
   x = x.contiguous();
   index = index.contiguous();
@@ -100,23 +100,92 @@ torch::Tensor segment_mean_cuda_forward(
   auto stream = at::cuda::getCurrentCUDAStream();
 
   // AT_DISPATCH_ALL_TYPES(x.scalar_type(), "__ops_name", [&] {
-  using scalar_t = float;  // temporary usage, delete later
-  auto x_data = x.data_ptr<scalar_t>();
-  auto out_data = out.data_ptr<scalar_t>();
-  auto index_data = index.data_ptr<int64_t>();
-
-  torch::Tensor count = torch::full_like(out, 0.0, x.options());
-  scalar_t *count_data = count.data_ptr<scalar_t>();
-
-  segment_mean_cuda_forward_kernel<scalar_t>
-      <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
-          x_data, index_data, out_data, count_data, E, K, N, x.numel());
-
-  arg_segment_mean_cuda_forward_kernel<scalar_t>
-      <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
-          x_data, index_data, out_data, arg_out_data, count_data, E, K, N,
-          x.numel());
+  // using scalar_t = float;  // temporary usage, delete later
+  // auto x_data = x.data_ptr<scalar_t>();
+  // auto out_data = out.data_ptr<scalar_t>();
+  // auto index_data = index.data_ptr<int64_t>();
+
+  // torch::Tensor count = torch::full_like(out, 0.0, x.options());
+  // scalar_t *count_data = count.data_ptr<scalar_t>();
+
+  // segment_mean_cuda_forward_kernel<scalar_t>
+  //     <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+  //         x_data, index_data, out_data, count_data, E, K, N, x.numel());
+
+  // arg_segment_mean_cuda_forward_kernel<scalar_t>
+  //     <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+  //         x_data, index_data, out_data, arg_out_data, count_data, E, K, N,
+  //         x.numel());
   // });
+
+  if (x.dtype() == torch::kInt8 || x.dtype() == torch::kInt16 || x.dtype() == torch::kInt32 || x.dtype() == torch::kInt64) {
+    auto type = x.dtype();
+    using scalar_t = int;
+    if (x.dtype() == torch::kInt8 || x.dtype() == torch::kInt16 || x.dtype() == torch::kInt64) {
+      x = x.to(torch::kInt32);
+      out = out.to(torch::kInt32);
+    }
+    // using scalar_t = float;  // temporary usage, delete later
+    auto x_data = x.data_ptr<scalar_t>();
+    auto out_data = out.data_ptr<scalar_t>();
+    auto index_data = index.data_ptr<int64_t>();
+
+    torch::Tensor count = torch::full_like(out, 0.0, x.options());
+    scalar_t *count_data = count.data_ptr<scalar_t>();
+
+    segment_mean_cuda_forward_kernel<scalar_t>
+        <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+            x_data, index_data, out_data, count_data, E, K, N, x.numel());
+
+    arg_segment_mean_cuda_forward_kernel<scalar_t>
+        <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+            x_data, index_data, out_data, arg_out_data, count_data, E, K, N,
+            x.numel());
+
+    out = out.to(type);
+  } else if (x.dtype() == torch::kFloat16 || x.dtype() == torch::kFloat32) {
+    auto type = x.dtype();
+    using scalar_t = float;
+    if (x.dtype() == torch::kFloat16) {
+      x = x.to(torch::kFloat32);
+      out = out.to(torch::kFloat32);
+    }
+
+    auto x_data = x.data_ptr<scalar_t>();
+    auto out_data = out.data_ptr<scalar_t>();
+    auto index_data = index.data_ptr<int64_t>();
+
+    torch::Tensor count = torch::full_like(out, 0.0, x.options());
+    scalar_t *count_data = count.data_ptr<scalar_t>();
+
+    segment_mean_cuda_forward_kernel<scalar_t>
+        <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+            x_data, index_data, out_data, count_data, E, K, N, x.numel());
+
+    arg_segment_mean_cuda_forward_kernel<scalar_t>
+        <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+            x_data, index_data, out_data, arg_out_data, count_data, E, K, N,
+            x.numel());
+
+    out = out.to(type);
+  } else if (x.dtype() == torch::kFloat64) {
+    using scalar_t = double;
+    auto x_data = x.data_ptr<scalar_t>();
+    auto out_data = out.data_ptr<scalar_t>();
+    auto index_data = index.data_ptr<int64_t>();
+
+    torch::Tensor count = torch::full_like(out, 0.0, x.options());
+    scalar_t *count_data = count.data_ptr<scalar_t>();
+
+    segment_mean_cuda_forward_kernel<scalar_t>
+        <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+            x_data, index_data, out_data, count_data, E, K, N, x.numel());
+
+    arg_segment_mean_cuda_forward_kernel<scalar_t>
+        <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+            x_data, index_data, out_data, arg_out_data, count_data, E, K, N,
+            x.numel());
+  }
+
   return out;
 }
diff --git a/gammagl/mpops/torch_ext/cuda/segment_sum_cuda.cu b/gammagl/mpops/torch_ext/cuda/segment_sum_cuda.cu
index 174bf59a..c8eab5e7 100644
--- a/gammagl/mpops/torch_ext/cuda/segment_sum_cuda.cu
+++ b/gammagl/mpops/torch_ext/cuda/segment_sum_cuda.cu
@@ -41,8 +41,8 @@ torch::Tensor segment_sum_cuda_forward(
       x.size(0) == index.size(0),
       "fisrt dimension of x and index should be same");
   // only support float Tensor
-  TORCH_CHECK_TYPE(
-      x.scalar_type() == c10::ScalarType::Float, "x should be float Tensor")
+  // TORCH_CHECK_TYPE(
+  //     x.scalar_type() == c10::ScalarType::Float, "x should be float Tensor")
   cudaSetDevice(x.get_device());
   x = x.contiguous();
   index = index.contiguous();
@@ -67,15 +67,61 @@ torch::Tensor segment_sum_cuda_forward(
   auto stream = at::cuda::getCurrentCUDAStream();
 
   // AT_DISPATCH_ALL_TYPES(x.scalar_type(), "__ops_name", [&] {
-  using scalar_t = float;  // temporary usage, delete later
-  auto x_data = x.data_ptr<scalar_t>();
-  auto out_data = out.data_ptr<scalar_t>();
-  auto index_data = index.data_ptr<int64_t>();
-
-  segment_sum_cuda_forward_kernel<scalar_t>
-      <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
-          x_data, index_data, out_data, E, K, N, x.numel());
+  // using scalar_t = float;  // temporary usage, delete later
+  // using scalar_t = x.scalar_type();  // temporary usage, delete later
+  // auto x_data = x.data_ptr<scalar_t>();
+  // auto out_data = out.data_ptr<scalar_t>();
+  // auto index_data = index.data_ptr<int64_t>();
+
+  // segment_sum_cuda_forward_kernel<scalar_t>
+  //     <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+  //         x_data, index_data, out_data, E, K, N, x.numel());
   // });
+
+  if (x.dtype() == torch::kInt8 || x.dtype() == torch::kInt16 || x.dtype() == torch::kInt32 || x.dtype() == torch::kInt64) {
+    auto type = x.dtype();
+    using scalar_t = int;
+    if (x.dtype() == torch::kInt8 || x.dtype() == torch::kInt16 || x.dtype() == torch::kInt64) {
+      x = x.to(torch::kInt32);
+      out = out.to(torch::kInt32);
+    }
+    std::cout << x.dtype() << std::endl;
+    auto x_data = x.data_ptr<scalar_t>();
+    auto out_data = out.data_ptr<scalar_t>();
+    auto index_data = index.data_ptr<int64_t>();
+
+    segment_sum_cuda_forward_kernel<scalar_t>
+        <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+            x_data, index_data, out_data, E, K, N, x.numel());
+
+    out = out.to(type);
+  } else if (x.dtype() == torch::kFloat16 || x.dtype() == torch::kFloat32) {
+    auto type = x.dtype();
+    using scalar_t = float;
+    if (x.dtype() == torch::kFloat16) {
+      x = x.to(torch::kFloat32);
+      out = out.to(torch::kFloat32);
+    }
+
+    auto x_data = x.data_ptr<scalar_t>();
+    auto out_data = out.data_ptr<scalar_t>();
+    auto index_data = index.data_ptr<int64_t>();
+
+    segment_sum_cuda_forward_kernel<scalar_t>
+        <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+            x_data, index_data, out_data, E, K, N, x.numel());
+
+    out = out.to(type);
+  } else if (x.dtype() == torch::kFloat64) {
+    using scalar_t = double;
+    auto x_data = x.data_ptr<scalar_t>();
+    auto out_data = out.data_ptr<scalar_t>();
+    auto index_data = index.data_ptr<int64_t>();
+
+    segment_sum_cuda_forward_kernel<scalar_t>
+        <<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
+            x_data, index_data, out_data, E, K, N, x.numel());
+  }
+
   return out;
 }
diff --git a/gammagl/mpops/torch_ext/include/utils.h b/gammagl/mpops/torch_ext/include/utils.h
index 6fa5cf19..99fcf613 100644
--- a/gammagl/mpops/torch_ext/include/utils.h
+++ b/gammagl/mpops/torch_ext/include/utils.h
@@ -5,4 +5,4 @@
 #include
 #include
 
-std::vector<int64_t> list2vec(const c10::List<int64_t> list);
\ No newline at end of file
+std::vector<int64_t> list2vec(const c10::List<int64_t> list);
diff --git a/gammagl/mpops/torch_ext/src/utils.cpp b/gammagl/mpops/torch_ext/src/utils.cpp
index 20df561b..531c26ac 100644
--- a/gammagl/mpops/torch_ext/src/utils.cpp
+++ b/gammagl/mpops/torch_ext/src/utils.cpp
@@ -12,4 +12,4 @@ std::vector<int64_t> list2vec(const c10::List<int64_t> list) {
   result.reserve(list.size());
   for (size_t i = 0; i < list.size(); ++i) result.push_back(list[i]);
   return result;
-}
\ No newline at end of file
+}
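
Note (editorial addition, not part of the patch): the CPU changes above replace the hard-coded `using scalar_t = float;` with ATen's `AT_DISPATCH_ALL_TYPES_AND2` macro, which instantiates the lambda body once per supported dtype and binds `scalar_t` to the tensor's concrete element type. A minimal, self-contained sketch of that pattern follows; the op name, shapes, and values are illustrative only and do not appear in the repository.

// dispatch_example.cpp -- illustrative sketch only; build against libtorch.
#include <ATen/Dispatch.h>
#include <torch/torch.h>
#include <iostream>

int main() {
  // 3 x 2 input; the same lambda body is compiled for every dispatched dtype.
  torch::Tensor x = torch::arange(6, torch::kFloat).reshape({3, 2});
  torch::Tensor out = torch::zeros({3}, x.options());

  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(),
      "example_row_sum", [&]() {
        // Inside the lambda, scalar_t is the concrete element type of x.
        auto x_data = x.data_ptr<scalar_t>();
        auto out_data = out.data_ptr<scalar_t>();
        for (int64_t i = 0; i < x.size(0); ++i)
          for (int64_t j = 0; j < x.size(1); ++j)
            out_data[i] += x_data[i * x.size(1) + j];
      });

  std::cout << out << std::endl;  // expected row sums: 1, 5, 9
  return 0;
}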