Skip to content

Commit

Permalink
[converter] aten::{index_select, scalarImplicit} (#183)
Browse files Browse the repository at this point in the history
* [converter] aten::index_select

* [converter] aten::scalarImplicit
  • Loading branch information
peterjc123 authored Feb 21, 2023
1 parent fb477df commit f1fe949
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/op_matrix.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ Operators that are implemented in Python
| `aten::hardtanh_` | |
| `aten::im2col` | only 4-D input tensors (batched image-like tensors) are supported |
| `aten::index` | Multiple indices for aten::index is not supported |
| `aten::index_select` | |
| `aten::instance_norm` | |
| `aten::layer_norm` | |
| `aten::le` | |
Expand Down Expand Up @@ -215,6 +216,7 @@ Non-tracking operators that are ignored during translation
| `aten::new_zeros` | |
| `aten::ones` | |
| `aten::ones_like` | |
| `aten::scalarImplicit` | |
| `aten::size` | |
| `aten::zeros` | |
| `aten::zeros_like` | |
30 changes: 30 additions & 0 deletions tests/converter_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3295,6 +3295,36 @@ def model(x):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_index_select(self):
dummy_input = torch.randn(10, 10, dtype=torch.float32)

def model(x):
return x.index_select(1, torch.arange(2))

model_path = get_model_path()

converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_index_select_neg_dim(self):
dummy_input = torch.randn(10, 10, dtype=torch.float32)

def model(x):
return x.index_select(-1, torch.arange(1, 5))

model_path = get_model_path()

converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_index(self):
dummy_input = torch.randn(10, 10, dtype=torch.float32)

Expand Down
2 changes: 2 additions & 0 deletions tinynn/converter/operators/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
"aten::instance_norm": ATenInstanceNormOperator,
"aten::group_norm": ATenGroupNormOperator,
"aten::index": ATenIndexOperator,
"aten::index_select": ATenIndexSelectOperator,
"aten::clone": ATenCloneOperator,
"aten::repeat": ATenRepeatOperator,
"aten::hardswish": ATenHardswishOperator,
Expand Down Expand Up @@ -211,4 +212,5 @@
"aten::empty": NoTrackOperator,
"aten::new_zeros": NoTrackOperator,
"aten::new_ones": NoTrackOperator,
"aten::scalarImplicit": NoTrackOperator,
}
24 changes: 24 additions & 0 deletions tinynn/converter/operators/torch/aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2032,6 +2032,30 @@ def parse(self, node, attrs, args, graph_converter):
actual_input = actual_output


class ATenIndexSelectOperator(ATenIndexSelectSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)

self.run(node)
dim = self.input_tensors[1]
indices = self.input_tensors[2]

assert indices.dtype in (torch.int64, torch.int32)

input_tensor = self.find_or_create_input(0, graph_converter)

if dim < 0:
dim += len(input_tensor.shape)

new_indices = indices.to(dtype=torch.int32)
new_indices = new_indices + (new_indices < 0).int() * input_tensor.shape[dim]

indices_tensor = self.create_attr_tensor(new_indices)
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)

graph_converter.add_operator(tfl.GatherOperator([input_tensor, indices_tensor], outputs, axis=dim))


class ATenLogSoftmaxOperator(ATenLogSoftmaxSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
Expand Down

0 comments on commit f1fe949

Please sign in to comment.