Skip to content

Commit

Permalink
[PT FE] Support non boolean inputs for __or__ and __and__ operations (o…
Browse files Browse the repository at this point in the history
…penvinotoolkit#19268)

* [PT FE] Support non boolean inputs for __or__ and __and__ operations

* Add test for __or__
  • Loading branch information
mvafin authored Aug 21, 2023
1 parent 3813b0b commit d55e45f
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 2 deletions.
42 changes: 42 additions & 0 deletions src/frontends/pytorch/src/op/logical.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/logical_and.hpp"
#include "openvino/op/logical_or.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_or(const NodeContext& context) {
num_inputs_check(context, 2, 2);
auto x = context.get_input(0);
auto y = context.get_input(1);
x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
y = context.mark_node(std::make_shared<v0::Convert>(y, element::boolean));
// TODO: use bitwise op here when will be supported by openvino
auto or_node = context.mark_node(std::make_shared<v1::LogicalOr>(x, y));
return {or_node};
};

OutputVector translate_and(const NodeContext& context) {
num_inputs_check(context, 2, 2);
auto x = context.get_input(0);
auto y = context.get_input(1);
x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
y = context.mark_node(std::make_shared<v0::Convert>(y, element::boolean));
// TODO: use bitwise op here when will be supported by openvino
auto or_node = context.mark_node(std::make_shared<v1::LogicalAnd>(x, y));
return {or_node};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
6 changes: 4 additions & 2 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ OP_CONVERTER(translate_add);
OP_CONVERTER(translate_addcmul);
OP_CONVERTER(translate_addmm);
OP_CONVERTER(translate_all);
OP_CONVERTER(translate_and);
OP_CONVERTER(translate_arange);
OP_CONVERTER(translate_argmax);
OP_CONVERTER(translate_argsort);
Expand Down Expand Up @@ -111,6 +112,7 @@ OP_CONVERTER(translate_norm);
OP_CONVERTER(translate_numel);
OP_CONVERTER(translate_ones);
OP_CONVERTER(translate_ones_like);
OP_CONVERTER(translate_or);
OP_CONVERTER(translate_outer);
OP_CONVERTER(translate_pad);
OP_CONVERTER(translate_pairwise_distance);
Expand Down Expand Up @@ -202,11 +204,11 @@ OP_CONVERTER(translate_transpose_fx);
// Supported ops for TorchScript
const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
return {
{"aten::__and__", op::translate_1to1_match_2_inputs<opset10::LogicalAnd>}, // TODO: cover numerical cases
{"aten::__and__", op::translate_and},
{"aten::__derive_index", op::translate_derive_index},
{"aten::__getitem__", op::translate_getitem},
{"aten::__not__", op::translate_1to1_match_1_inputs<opset10::LogicalNot>},
{"aten::__or__", op::translate_1to1_match_2_inputs<opset10::LogicalOr>},
{"aten::__or__", op::translate_or},
{"aten::__range_length", op::translate_range_length},
{"aten::_convolution", op::translate_convolution},
{"aten::_convolution_mode", op::translate_convolution_mode},
Expand Down
28 changes: 28 additions & 0 deletions tests/layer_tests/pytorch_tests/test_or.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
from pytorch_layer_test_class import PytorchLayerTest


class TestLog(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randint(0, 255, (20, 30, 40, 50)),)

def create_model(self):
import torch

class aten_or(torch.nn.Module):
def forward(self, x):
res = torch.ByteTensor(x.size()).zero_()
res[:, :, :, 1:] = res[:, :, :, 1:] | (x[:, :, :, 1:] != x[:, :, :, :-1])
res[:, :, :, :-1] = res[:, :, :, :-1] | (x[:, :, :, 1:] != x[:, :, :, :-1])
return res.float()

return aten_or(), None, "aten::__or__"

@pytest.mark.nightly
@pytest.mark.precommit
def test_or(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version, dynamic_shapes=False, trace_model=True)

0 comments on commit d55e45f

Please sign in to comment.