Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

add logical_right_shift op #1083

Merged
merged 8 commits into from
Dec 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cinn/frontend/net_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ NETBUILDER_BINARY_OP_DEF(Equal, equal);
NETBUILDER_BINARY_OP_DEF(NotEqual, not_equal);
NETBUILDER_BINARY_OP_DEF(GreaterEqual, greater_equal);
NETBUILDER_BINARY_OP_DEF(LessEqual, less_equal);
NETBUILDER_BINARY_OP_DEF(LogicalRightShift, logical_right_shift);

#undef NETBUILDER_BINARY_OP_DEF

Expand Down
3 changes: 2 additions & 1 deletion cinn/frontend/net_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ namespace frontend {
macro__(GreaterThan) \
macro__(LessThan) \
macro__(GreaterEqual) \
macro__(LessEqual)
macro__(LessEqual) \
macro__(LogicalRightShift)

// ******************************************* //
// Reduce array elements over the given dims.
Expand Down
2 changes: 2 additions & 0 deletions cinn/hlir/op/contrib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ gather_srcs(cinnapi_src SRCS
clz.cc
popc.cc
reciprocal.cc
logical_right_shift.cc
)

cc_test(test_gather_nd SRCS gather_nd_test.cc DEPS cinncore)
Expand All @@ -29,3 +30,4 @@ cc_test(test_cbrt SRCS cbrt_test.cc DEPS cinncore)
cc_test(test_clz SRCS clz_test.cc DEPS cinncore)
cc_test(test_popc SRCS popc_test.cc DEPS cinncore)
cc_test(test_reciprocal SRCS reciprocal_test.cc DEPS cinncore)
cc_test(test_logical_right_shift SRCS logical_right_shift_test.cc DEPS cinncore)
156 changes: 156 additions & 0 deletions cinn/hlir/op/contrib/logical_right_shift.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "cinn/common/cas.h"
#include "cinn/common/common.h"
#include "cinn/common/context.h"
#include "cinn/common/macros.h"
#include "cinn/common/target.h"
#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/op.h"
#include "cinn/hlir/framework/op_strategy.h"
#include "cinn/hlir/pe/ir_schedule_pe.h"
#include "cinn/hlir/pe/nn.h"
#include "cinn/hlir/pe/schedule.h"
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_base.h"
#include "cinn/ir/ir_operators.h"
#include "cinn/ir/tensor.h"
#include "cinn/lang/builtin.h"
#include "cinn/lang/compute.h"
#include "gflags/gflags.h"

DECLARE_bool(cinn_ir_schedule);

namespace cinn {
namespace hlir {
namespace op {

using common::_CINNValuePack_;
using common::CINNValue;
using common::CINNValuePack;
using framework::OpStrategy;
using framework::shape_t;
using framework::StrategyFunction;

ir::Tensor LogicalRightShift(const ir::Tensor &A,
const ir::Tensor &B,
const Target &target,
const std::string &output_name) {
std::string extern_func = "cinn_";
if (target == common::DefaultHostTarget()) {
extern_func += "host_";
} else if (target == common::DefaultNVGPUTarget()) {
extern_func += "nvgpu_";
} else {
CINN_NOT_IMPLEMENTED
}

extern_func += "logical_right_shift";

if (A->type().is_int(32) || A->type().is_uint(32)) {
extern_func += "_int32";
} else {
CINN_NOT_IMPLEMENTED
}

return Compute(
A->shape,
[=](const std::vector<Expr> &indices) {
Expr x = A(indices);
Expr y = B(indices);
return lang::CallExtern(extern_func, {x, y});
},
output_name);
}

std::shared_ptr<OpStrategy> StrategyForLogicalRightShift(const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const std::vector<std::vector<int>> &output_shapes,
const Target &target) {
std::string op_name("logical_right_shift");

framework::CINNCompute logical_right_shift_compute([=](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input argument of " << op_name << " compute is empty! Please check.\n";
CINNValuePack pack_args = args[0];
CHECK_GE(pack_args.size(), 2U) << "2 input tensors for " << op_name << " compute\n";

Expr A_expr = pack_args[0];
Expr B_expr = pack_args[1];
CHECK(A_expr.as_tensor());
CHECK(B_expr.as_tensor());
ir::Tensor A = A_expr.as_tensor_ref();
ir::Tensor B = B_expr.as_tensor_ref();

std::string tensor_name = UniqName("T_LogicalRightShift_out");

if (FLAGS_cinn_ir_schedule) {
CHECK_EQ(pack_args.size(), 3U);
tensor_name = pack_args[2].operator std::string();
}

auto out = LogicalRightShift(A, B, target, tensor_name);
auto stages = CreateStages({out});
*ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
});

auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(logical_right_shift_compute,
framework::GetInjectiveScheduleFunc(output_shapes, target),
"strategy.logical_right_shift.x86",
1);
return strategy;
}

std::vector<framework::shape_t> InferShapeForLogicalRightShift(const std::vector<framework::shape_t> &inputs_shape,
const framework::AttrMapType &attrs) {
CHECK_EQ(inputs_shape.size(), 2U) << "The input's shape size should be 2! Please check again.";
CHECK_EQ(inputs_shape[0].size(), inputs_shape[1].size()) << "The inputs' dims should be equal.";
std::vector<framework::shape_t> res{inputs_shape[0]};
return res;
}

std::vector<Type> InferDtypeForLogicalRightShift(const std::vector<Type> &inputs_type,
const framework::AttrMapType &attrs) {
CHECK_EQ(inputs_type.size(), 2UL) << "The logical_right_shift op should has two inputs! Please check.";
CHECK_EQ(inputs_type[0], inputs_type[1])
<< "The data type of input tensors of logical_right_shift op should be equal, but here x:" << inputs_type[0]
<< " != y:" << inputs_type[1] << "! Please check.";
std::vector<Type> res{inputs_type[0]};
return res;
}

} // namespace op
} // namespace hlir
} // namespace cinn

CINN_REGISTER_HELPER(logical_right_shift_ops) {
CINN_REGISTER_OP(logical_right_shift)
.describe("Logical Right Shift.")
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForLogicalRightShift)
.set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForLogicalRightShift))
.set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForLogicalRightShift))
.set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise)
.set_support_level(4);

return true;
}
35 changes: 35 additions & 0 deletions cinn/hlir/op/contrib/logical_right_shift.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>

#include "cinn/ir/ir.h"
#include "cinn/ir/ir_base.h"
#include "cinn/ir/tensor.h"

namespace cinn {
namespace hlir {
namespace op {

ir::Tensor LogicalRightShift(const ir::Tensor& A,
const ir::Tensor& B,
const Target& target,
const std::string& output_name);

} // namespace op
} // namespace hlir
} // namespace cinn
64 changes: 64 additions & 0 deletions cinn/hlir/op/contrib/logical_right_shift_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/hlir/op/contrib/logical_right_shift.h"

#include <glog/logging.h>
#include <gtest/gtest.h>

#include <string>
#include <vector>

#include "cinn/backends/codegen_c.h"
#include "cinn/backends/codegen_c_x86.h"
#include "cinn/backends/codegen_cuda_dev.h"
#include "cinn/common/context.h"
#include "cinn/lang/lower.h"
#include "cinn/lang/placeholder.h"
#include "cinn/poly/stage.h"

namespace cinn {
namespace hlir {
namespace op {

TEST(GenerateCode_Cpu, LogicalRightShift) {
common::Context::Global().ResetNameId();

common::Target target = common::DefaultHostTarget();
lang::Placeholder<int> x("x", std::vector<int>{10});
lang::Placeholder<int> y("y", std::vector<int>{10});
ir::Tensor res = LogicalRightShift(x, y, target, "test_logical_right_shift");

poly::StageMap stages = poly::CreateStages({res});
std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestGenerateCodeCpu_LogicalRightShift", stages, {res}, {}, {}, nullptr, target, true);

VLOG(6) << "Expr before CPU codegen:";
VLOG(6) << funcs[0]->body;

ir::Module::Builder builder("LogicalRightShift_Module", target);
for (auto& f : funcs) {
builder.AddFunction(f);
}

backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512);
codegen.SetInlineBuiltinCodes(false);
std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl);
VLOG(6) << "Cpu Codegen result:";
VLOG(6) << code << std::endl;
}

} // namespace op
} // namespace hlir
} // namespace cinn
1 change: 1 addition & 0 deletions cinn/hlir/op/use_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ CINN_USE_REGISTER(cbrt_ops)
CINN_USE_REGISTER(clz_ops)
CINN_USE_REGISTER(popc_ops)
CINN_USE_REGISTER(reciprocal_ops)
CINN_USE_REGISTER(logical_right_shift_ops)
4 changes: 4 additions & 0 deletions cinn/runtime/cpu/host_intrinsics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ inline int FN_INT32(clz)(int x) { return __builtin_clz(x); }

inline int FN_INT32(popc)(int x) { return __builtin_popcount(x); }

inline int FN_INT32(logical_right_shift)(int x, int y) { return (x >> y) & ~(((0x1 << 31) >> y) << 1); }

#undef FN_INT32

#define FN_INT64(func) cinn_host_##func##_int64
Expand Down Expand Up @@ -200,6 +202,8 @@ CINN_REGISTER_HELPER(host_intrinsics) {

REGISTER_EXTERN_FUNC_2_IN_1_INT32(pow)

REGISTER_EXTERN_FUNC_2_IN_1_INT32(logical_right_shift)

#undef REGISTER_EXTERN_FUNC_2_IN_1_INT32

REGISTER_EXTERN_FUNC_1_IN_1_OUT(cinn_host_clz_int32, host_target, int, int);
Expand Down
2 changes: 2 additions & 0 deletions cinn/runtime/cpu/host_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ inline int FN_INT32(clz)(int x);

inline int FN_INT32(popc)(int x);

inline int FN_INT32(logical_right_shift)(int x, int y);

#undef FN_INT32

#define FN_INT64(func) cinn_host_##func##_int64
Expand Down
1 change: 1 addition & 0 deletions cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ __device__ inline int FN_INT32(bitwise_xor)(int a, int b) { return a ^ b; }
__device__ inline int FN_INT32(bitwise_not)(int a) { return ~a; }
__device__ inline int FN_INT32(clz)(int a) { return __clz(a); }
__device__ inline int FN_INT32(popc)(int a) { return __popc(a); }
__device__ inline int FN_INT32(logical_right_shift)(int a, int b) { return (a >> b) & ~(((0x1 << 31) >> b) << 1); }

// *************************************************************** //

Expand Down
1 change: 1 addition & 0 deletions cinn/runtime/cuda/cuda_intrinsics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ CINN_REGISTER_HELPER(cuda_intrinsics) {
REGISTER_EXTERN_FUNC_2_IN_1_INT32(bitwise_or)
REGISTER_EXTERN_FUNC_2_IN_1_INT32(bitwise_xor)
REGISTER_EXTERN_FUNC_2_IN_1_INT32(floor_divide)
REGISTER_EXTERN_FUNC_2_IN_1_INT32(logical_right_shift)

#undef REGISTER_EXTERN_FUNC_2_IN_1_INT32

Expand Down
Loading