Skip to content

Commit

Permalink
support more data type
Browse files — browse the repository at this point in the history
  • Loading branch information
gyzhou2000 committed Apr 1, 2024
1 parent dda2c24 commit 6279fd5
Show file tree
Hide file tree
Showing 8 changed files with 372 additions and 152 deletions.
55 changes: 28 additions & 27 deletions gammagl/mpops/torch_ext/cpu/segment_max_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,45 +23,46 @@ std::tuple<torch::Tensor, torch::Tensor> segment_max_cpu_forward(

auto sizes = x.sizes().vec();
sizes[0] = N;
torch::Tensor out = torch::empty(sizes, x.options());
torch::Tensor out = torch::zeros(sizes, x.options());
torch::Tensor arg_out = torch::full_like(out, out.size(0), index.options());
if (x.numel() == 0) {
out.fill_(0.);
return std::make_tuple(out, arg_out);
}

out.fill_(std::numeric_limits<int64_t>::lowest());
auto E = x.size(0);
auto K = x.numel() / x.size(0);
auto K = x.numel() / E;
auto index_data = index.data_ptr<int64_t>();
auto arg_out_data = arg_out.data_ptr<int64_t>();

using scalar_t = float;
auto x_data = x.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "segment_mean_cpu_forward", [&]() {
out.fill_(std::numeric_limits<scalar_t>::lowest());
auto x_data = x.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();

int64_t idx;
#ifdef COMPILE_WITH_OMP
#pragma omp parallel for private(idx)
#endif
for (auto e = 0; e < E; ++e) {
idx = index_data[e];
TORCH_CHECK_INDEX(idx < N, "Index out of bounds: ", idx, " >= ", N);
for (auto k = 0; k < K; ++k) {
scalar_t current_val = x_data[e * K + k];
scalar_t& max_val = out_data[idx * K + k];
int64_t& max_idx = arg_out_data[idx * K + k];
#ifdef COMPILE_WITH_OMP
#pragma omp critical
#endif
{
if (max_val < current_val) {
max_val = current_val;
max_idx = e;
int64_t idx;
#ifdef COMPILE_WITH_OMP
#pragma omp parallel for private(idx)
#endif
for (auto e = 0; e < E; ++e) {
idx = index_data[e];
TORCH_CHECK_INDEX(idx < N, "Index out of bounds: ", idx, " >= ", N);
for (auto k = 0; k < K; ++k) {
scalar_t current_val = x_data[e * K + k];
scalar_t& max_val = out_data[idx * K + k];
int64_t& max_idx = arg_out_data[idx * K + k];
#ifdef COMPILE_WITH_OMP
#pragma omp critical
#endif
{
if (max_val < current_val) {
max_val = current_val;
max_idx = e;
}
}
}
}
}
}

});

return std::make_tuple(out, arg_out);
}
74 changes: 38 additions & 36 deletions gammagl/mpops/torch_ext/cpu/segment_mean_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <torch/extension.h>
#include <torch/script.h>
#include <torch/torch.h>
#include <ATen/ATen.h>

#include <iostream>
#include <vector>
Expand All @@ -30,49 +31,50 @@ torch::Tensor segment_mean_cpu_forward(
}

auto E = x.size(0);
auto K = x.numel() / x.size(0);
auto K = x.numel() / E;
auto index_data = index.data_ptr<int64_t>();
auto arg_out_data = arg_out.data_ptr<int64_t>();

// AT_DISPATCH_ALL_TYPES(x.scalar_type(), "__ops_name", [&] {
using scalar_t = float;
auto x_data = x.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "segment_mean_cpu_forward", [&]() {

torch::Tensor degree = torch::zeros({1, index.size(0)}, x.options());
auto degree_data = degree.data_ptr<scalar_t>();
auto x_data = x.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();

#ifdef COMPILE_WITH_OMP
#pragma omp parallel for
#endif
for (auto e = 0; e < E; ++e) {
auto idx = index_data[e];
degree_data[idx] += 1;
for (auto k = 0; k < K; ++k) {
#ifdef COMPILE_WITH_OMP
#pragma omp critical
#endif
out_data[idx * K + k] += x_data[e * K + k];
arg_out_data[idx * K + k] = e;
}
}
// });
out = out.contiguous();
degree = degree.contiguous();
torch::Tensor degree = torch::zeros({1, index.size(0)}, x.options());
auto degree_data = degree.data_ptr<scalar_t>();

#ifdef COMPILE_WITH_OMP
#pragma omp parallel for
#endif
for (auto e = 0; e < E; ++e) {
if (degree_data[e] > 1) {
for (auto k = 0; k < K; ++k) {
#ifdef COMPILE_WITH_OMP
#pragma omp critical
#endif
out_data[e * K + k] /= degree_data[e];
#ifdef COMPILE_WITH_OMP
#pragma omp parallel for
#endif
for (auto e = 0; e < E; ++e) {
auto idx = index_data[e];
degree_data[idx] += 1;
for (auto k = 0; k < K; ++k) {
#ifdef COMPILE_WITH_OMP
#pragma omp critical
#endif
out_data[idx * K + k] += x_data[e * K + k];
arg_out_data[idx * K + k] = e;
}
}
}
}
out = out.contiguous();
degree = degree.contiguous();

#ifdef COMPILE_WITH_OMP
#pragma omp parallel for
#endif
for (auto e = 0; e < E; ++e) {
if (degree_data[e] > 1) {
for (auto k = 0; k < K; ++k) {
#ifdef COMPILE_WITH_OMP
#pragma omp critical
#endif
out_data[e * K + k] /= degree_data[e];
}
}
}

});

return out;
}
54 changes: 25 additions & 29 deletions gammagl/mpops/torch_ext/cpu/segment_sum_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <torch/extension.h>
#include <torch/script.h>
#include <torch/torch.h>
#include <ATen/ATen.h>

#include <iostream>
#include <vector>
Expand All @@ -13,50 +14,45 @@ torch::Tensor segment_sum_cpu_forward(
TORCH_CHECK(index.device().is_cpu(), "index must be CPU tensor");
TORCH_CHECK_INDEX(
index.dim() == 1, "index dimension should be 1, but got ", index.dim());

TORCH_CHECK_INDEX(
x.size(0) == index.size(0),
"fisrt dimension of x and index should be same");

x = x.contiguous(); // Make sure x is contiguous.
index = index.contiguous();

// Set up the sizes for the output tensor.
auto sizes = x.sizes().vec();
sizes[0] = N;

// Initialize the output tensor with zeros.
torch::Tensor out = torch::zeros(sizes, x.options());

// If there is no element in x, return the output tensors as they are.
if (x.numel() == 0) {
return out;
}

// Get data pointers for index, arg_out, and x.
auto index_data = index.data_ptr<int64_t>();
auto x_data = x.data_ptr<float>(); // Assuming x is of type float.
auto out_data = out.data_ptr<float>();

// Set up dimensions for iteration.
auto E = index.size(0); // Number of elements to process.
// auto K = (x.dim() > 1) ? x.size(1) : 1; // Size of the inner dimension.
auto K = x.numel() / x.size(0); // Size of the inner dimension.

#ifdef COMPILE_WITH_OMP
#pragma omp parallel for
#endif
// Iterate over each element in x.
for (auto e = 0; e < E; ++e) {
auto idx = index_data[e];
// Handle accumulation for different dimensions.
for (auto k = 0; k < K; ++k) {
#ifdef COMPILE_WITH_OMP
#pragma omp critical
#endif
out_data[idx * K + k] += x_data[e * K + k];
}
}
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "segment_sum_cpu_forward", [&]() {
// Get data pointers for index, arg_out, and x.
auto index_data = index.data_ptr<int64_t>();
auto x_data = x.data_ptr<scalar_t>(); // Assuming x is of type float.
auto out_data = out.data_ptr<scalar_t>();

auto E = index.size(0); // Number of elements to process.
auto K = x.numel() / x.size(0); // Size of the inner dimension.

#ifdef COMPILE_WITH_OMP
#pragma omp parallel for
#endif
// Iterate over each element in x.
for (auto e = 0; e < E; ++e) {
auto idx = index_data[e];
// Handle accumulation for different dimensions.
for (auto k = 0; k < K; ++k) {
#ifdef COMPILE_WITH_OMP
#pragma omp critical
#endif
out_data[idx * K + k] += x_data[e * K + k];
}
}
});

return out;
}
Loading

0 comments on commit 6279fd5

Please sign in to comment.