Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

add atan2 op #1058

Merged
merged 9 commits into from
Nov 23, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cinn/frontend/net_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ NETBUILDER_UNARY_OP_DEF(Abs, abs)
return BinaryOp(#op_type__, lhs, rhs, axis); \
}
NETBUILDER_BINARY_OP_DEF(Add, elementwise_add)
NETBUILDER_BINARY_OP_DEF(Atan2, elementwise_atan2)
NETBUILDER_BINARY_OP_DEF(Multiply, elementwise_mul)
NETBUILDER_BINARY_OP_DEF(Divide, divide)
NETBUILDER_BINARY_OP_DEF(Subtract, substract)
Expand Down
3 changes: 2 additions & 1 deletion cinn/frontend/net_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ namespace frontend {
// Variable BINARY_OP(const Variable& lhs, const Variable& rhs, int axis = -1);
#define NETBUILDER_BINARY_OP_FOREACH(macro__) \
macro__(Add) \
macro__(Atan2) \
macro__(Subtract) \
macro__(Divide) \
macro__(Multiply) \
Expand Down Expand Up @@ -221,7 +222,7 @@ class NetBuilder {
#undef NETBUILDER_UNARY_OP_DECL

/**
* @brief Compute each each element in `lhs` variable and `rhs` variable in `axis` dimension, and return the result
* @brief Compute each element in `lhs` variable and `rhs` variable in `axis` dimension, and return the result
* Variable.
* @param lhs The left input variable.
* @param rhs The right input variable.
Expand Down
2 changes: 2 additions & 0 deletions cinn/hlir/op/broadcast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ std::vector<std::vector<std::string>> InferLayoutForIsClose(const std::vector<st
}

StrategyForBinary(elementwise_add, Add);
StrategyForBinary(elementwise_atan2, Atan2);
StrategyForBinary(elementwise_mul, Multiply);

StrategyForBinary(substract, Substract);
Expand Down Expand Up @@ -434,6 +435,7 @@ CINN_REGISTER_HELPER(broadcast_ops) {
.set_support_level(4);

CINN_REGISTER_BINARY(elementwise_add, Add);
CINN_REGISTER_BINARY(elementwise_atan2, Atan2);
CINN_REGISTER_BINARY(elementwise_mul, Multiply);

CINN_REGISTER_BINARY(substract, Substract);
Expand Down
16 changes: 16 additions & 0 deletions cinn/hlir/pe/broadcast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,22 @@ Tensor Pow(
return Broadcast(fn, A, B, output_name, axis);
}

// Full-precision pi. NOTE: the previous literal (3.14159265357989) was
// mistyped past the 10th decimal — the correct digits are ...358979...
#define PI 3.14159265358979323846
//! Compute elementwise atan2(A, B) with auto-broadcasting.
//! \param A numerator tensor (the "y" of atan2).
//! \param B denominator tensor (the "x" of atan2).
//! \param output_name name of the produced tensor.
//! \param axis broadcast axis, forwarded to Broadcast.
Tensor Atan2(const Tensor& A, const Tensor& B, const std::string& output_name, const Expr& axis) {
  auto fn = [&](const Expr& elem_a, const Expr& elem_b) {
    // atan(a/b) only determines the angle up to the quadrant; the selects
    // below restore the correct quadrant from the signs of a and b.
    auto atan    = lang::Atan(elem_a / elem_b);
    auto pi      = ir::Cast::Make(atan->type(), Expr(PI));
    auto half_pi = ir::Cast::Make(atan->type(), Expr(PI / 2));
    auto zero    = ir::Cast::Make(atan->type(), Expr(0));
    // b == 0 : +/-pi/2 by the sign of a.
    //   NOTE(review): a == 0, b == 0 yields -pi/2 here, while std::atan2
    //   returns 0 for (0, 0) — confirm this convention is intended.
    // b > 0  : atan(a/b) already lies in the correct quadrant.
    // b < 0  : shift the result by +pi (a >= 0) or -pi (a < 0).
    return ir::Select::Make(
        ir::EQ::Make(elem_b, zero),
        ir::Select::Make(ir::GT::Make(elem_a, zero), half_pi, -half_pi),
        ir::Select::Make(
            ir::GT::Make(elem_b, zero), atan, ir::Select::Make(ir::GE::Make(elem_a, zero), atan + pi, atan - pi)));
  };
  return Broadcast(fn, A, B, output_name, axis);
}

Tensor BroadcastTo(const Tensor& A,
const std::vector<int>& out_shape,
const std::vector<int>& broadcast_axes,
Expand Down
2 changes: 2 additions & 0 deletions cinn/hlir/pe/broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ void GetBroadcastOutShape(const std::vector<int>& input_shape1,

//! Compute A + B with auto-broadcasting.
HLIR_DCL_BC_PE(Add);
//! Compute Atan2 with auto-broadcasting.
HLIR_DCL_BC_PE(Atan2);
//! Compute A - B with auto-broadcasting.
HLIR_DCL_BC_PE(Substract);
//! Compute A * B with auto-broadcasting.
Expand Down
22 changes: 22 additions & 0 deletions cinn/hlir/pe/pe_broadcast_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include <gtest/gtest.h>
#include <math.h>

#include "cinn/backends/llvm/execution_engine.h"
#include "cinn/cinn.h"
Expand Down Expand Up @@ -193,6 +194,27 @@ void TestBroadcastPE2(
TEST_BROADCAST_PE_FP32(Add, return a + b;)
TEST_BROADCAST_PE_FP32(Multiply, return a * b;)

#define PI 3.1415926535
// Reference atan2 used to validate the Atan2 PE on float inputs.
// Follows the kernel's convention for the degenerate denominator:
// when b == 0 the result is +PI/2 for a > 0 and -PI/2 otherwise
// (so a == 0, b == 0 also yields -PI/2).
float Atan2(float a, float b) {
  if (b == 0.0) {
    // Denominator is zero: the angle is on the vertical axis.
    return a > 0 ? PI / 2 : -PI / 2;
  }
  // atan only resolves the angle up to the quadrant; adjust by +/-PI
  // when the denominator is negative.
  auto quadrant_angle = atan(a / b);
  if (b > 0) {
    return quadrant_angle;
  }
  return a >= 0 ? quadrant_angle + PI : quadrant_angle - PI;
}
TEST_BROADCAST_PE_FP32_BASIC(Atan2);

} // namespace pe
} // namespace hlir
} // namespace cinn
1 change: 1 addition & 0 deletions cinn/pybind/pe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ void BindPE(py::module* m) {
m->def(#name__, &hlir::pe::fn__, py::arg("x"), py::arg("y"), py::arg("out"), py::arg("axis") = Expr(-1))

BIND_BINARY(add, Add);
BIND_BINARY(atan2, Atan2);
BIND_BINARY(substract, Substract);
BIND_BINARY(multiply, Multiply);
BIND_BINARY(divide, Divide);
Expand Down
165 changes: 84 additions & 81 deletions cinn/runtime/cuda/cinn_cuda_runtime_source.cuh

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion python/tests/ops/test_binary_elementwise_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_check_results(self):
"add", "subtract", "divide", "multiply", "floor_divide", "mod",
"floor_mod", "max", "min", "logical_and", "logical_or", "logical_xor",
"bitwise_and", "bitwise_or", "bitwise_xor", "equal", "not_equal",
"greater_than", "less_than", "greater_equal", "less_equal"
"greater_than", "less_than", "greater_equal", "less_equal", "atan2"
]

for op_name in test_op_list:
Expand Down