Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[converter] aten::{index_select, scalarImplicit} #183

Merged
merged 2 commits into from
Feb 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/op_matrix.md
Original file line number Diff line number Diff line change
@@ -90,6 +90,7 @@ Operators that are implemented in Python
| `aten::hardtanh_` | |
| `aten::im2col` | only 4-D input tensors (batched image-like tensors) are supported |
| `aten::index` | Multiple indices for aten::index is not supported |
| `aten::index_select` | |
| `aten::instance_norm` | |
| `aten::layer_norm` | |
| `aten::le` | |
@@ -215,6 +216,7 @@ Non-tracking operators that are ignored during translation
| `aten::new_zeros` | |
| `aten::ones` | |
| `aten::ones_like` | |
| `aten::scalarImplicit` | |
| `aten::size` | |
| `aten::zeros` | |
| `aten::zeros_like` | |
30 changes: 30 additions & 0 deletions tests/converter_op_test.py
Original file line number Diff line number Diff line change
@@ -3295,6 +3295,36 @@ def model(x):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_index_select(self):
dummy_input = torch.randn(10, 10, dtype=torch.float32)

def model(x):
return x.index_select(1, torch.arange(2))

model_path = get_model_path()

converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_index_select_neg_dim(self):
dummy_input = torch.randn(10, 10, dtype=torch.float32)

def model(x):
return x.index_select(-1, torch.arange(1, 5))

model_path = get_model_path()

converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_index(self):
dummy_input = torch.randn(10, 10, dtype=torch.float32)

2 changes: 2 additions & 0 deletions tinynn/converter/operators/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -108,6 +108,7 @@
"aten::instance_norm": ATenInstanceNormOperator,
"aten::group_norm": ATenGroupNormOperator,
"aten::index": ATenIndexOperator,
"aten::index_select": ATenIndexSelectOperator,
"aten::clone": ATenCloneOperator,
"aten::repeat": ATenRepeatOperator,
"aten::hardswish": ATenHardswishOperator,
@@ -211,4 +212,5 @@
"aten::empty": NoTrackOperator,
"aten::new_zeros": NoTrackOperator,
"aten::new_ones": NoTrackOperator,
"aten::scalarImplicit": NoTrackOperator,
}
24 changes: 24 additions & 0 deletions tinynn/converter/operators/torch/aten.py
Original file line number Diff line number Diff line change
@@ -2032,6 +2032,30 @@ def parse(self, node, attrs, args, graph_converter):
actual_input = actual_output


class ATenIndexSelectOperator(ATenIndexSelectSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)

self.run(node)
dim = self.input_tensors[1]
indices = self.input_tensors[2]

assert indices.dtype in (torch.int64, torch.int32)

input_tensor = self.find_or_create_input(0, graph_converter)

if dim < 0:
dim += len(input_tensor.shape)

new_indices = indices.to(dtype=torch.int32)
new_indices = new_indices + (new_indices < 0).int() * input_tensor.shape[dim]

indices_tensor = self.create_attr_tensor(new_indices)
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)

graph_converter.add_operator(tfl.GatherOperator([input_tensor, indices_tensor], outputs, axis=dim))


class ATenLogSoftmaxOperator(ATenLogSoftmaxSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)