[converter] add map_bilstm_to_lstm
peterjc123 committed Jul 27, 2022
1 parent f20ba9b commit 8d84842
Showing 6 changed files with 201 additions and 10 deletions.
3 changes: 3 additions & 0 deletions docs/FAQ.md
@@ -163,6 +163,9 @@ Generally, for a vision model, the memory layout of the input data used by PyTor
#### Why does the converted model with grouped (de)convolution not work?
Since TFLite does not officially support grouped (de)convolution, we have extended the implementation of grouped (de)convolution internally based on the `CONV_2D` and `TRANSPOSE_CONV` operators. To generate a standard TFLite model, you can add the parameter `group_conv_rewrite=True` when defining `TFLiteConverter`.

#### What if my target platform supports `UnidirectionalLSTM` but not `BidirectionalLSTM`?
You may add the parameter `map_bilstm_to_lstm=True` when defining `TFLiteConverter`; each bidirectional LSTM is then lowered to a pair of unidirectional LSTM ops.
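
For example, a minimal sketch (the toy model, input shape, and output path below are placeholders, and the usual `from tinynn.converter import TFLiteConverter` import is assumed):

```python
import torch
import torch.nn as nn

from tinynn.converter import TFLiteConverter  # assumed import path


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(10, 20, bidirectional=True)

    def forward(self, x):
        return self.lstm(x)[0]


model = Model()
model.eval()

dummy_input = torch.randn(9, 1, 10)  # (seq_len, batch, input_size)

# map_bilstm_to_lstm=True rewrites each bidirectional LSTM into a pair of
# unidirectional LSTM ops in the generated TFLite model.
converter = TFLiteConverter(model, dummy_input, 'bilstm.tflite', map_bilstm_to_lstm=True)
converter.convert()
```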

#### How to convert a model with LSTM?
Since the target format of our conversion is TFLite, we first need to understand how LSTM works in PyTorch and TensorFlow, respectively.

3 changes: 3 additions & 0 deletions docs/FAQ_zh-CN.md
@@ -163,6 +163,9 @@ export_converter_files(model, dummy_input, export_dir, export_name)
#### Why does the converted model with grouped (de)convolution not work?
Since TFLite does not officially support grouped (de)convolution, we have extended the implementation internally based on the `CONV_2D` and `TRANSPOSE_CONV` operators. To generate a standard TFLite model, you can add the parameter `group_conv_rewrite=True` when defining `TFLiteConverter`.

#### What if my deployment platform only supports `UnidirectionalLSTM` and not `BidirectionalLSTM`?
You may add the parameter `map_bilstm_to_lstm=True` when defining `TFLiteConverter`.

#### How to convert a model with LSTM?
Since the target of our conversion is TFLite, we first need to understand how LSTM works in PyTorch and TensorFlow, respectively.

120 changes: 120 additions & 0 deletions tests/converter_op_test.py
@@ -2621,6 +2621,50 @@ def forward(self, x):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_bilstm_multi_layer_as_lstm(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.lstm = nn.LSTM(10, 20, 2, bidirectional=True)

def forward(self, x):
return self.lstm(x)[0]

model = Model()
model.eval()

model_path = get_model_path()
converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False, map_bilstm_to_lstm=True)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_bilstm_multi_layer_no_bias_as_lstm(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.lstm = nn.LSTM(10, 20, 2, bidirectional=True, bias=False)

def forward(self, x):
return self.lstm(x)[0]

model = Model()
model.eval()

model_path = get_model_path()
converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False, map_bilstm_to_lstm=True)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_sigmoid_(self):
dummy_input = torch.randn(1, 3, 224, 224, dtype=torch.float32)

@@ -5906,6 +5950,82 @@ def forward(self, x):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output, atol=256.0, rtol=256.0)

@unittest.skipIf(not hasattr(torch.nn.quantized.dynamic, 'LSTM'), 'Quantized lstm is not supported')
def test_bilstm_dynamic_as_lstm(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.lstm = torch.nn.quantized.dynamic.LSTM(10, 20, bidirectional=True)

def forward(self, x):
return self.lstm(x)[0]

model = Model()
model.eval()

model_path = get_model_path()
converter = TFLiteConverter(
model, dummy_input, model_path, nchw_transpose=False, quantize_target_type='int8', map_bilstm_to_lstm=True
)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output, atol=256.0, rtol=256.0)

@unittest.skipIf(not hasattr(torch.nn.quantized.dynamic, 'LSTM'), 'Quantized lstm is not supported')
def test_bilstm_dynamic_batch_first_as_lstm(self):
raise unittest.SkipTest('TFLite hybrid LSTM kernel with batch_first=True is broken')
dummy_input = torch.randn(1, 9, 10, dtype=torch.float32)

class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.lstm = torch.nn.quantized.dynamic.LSTM(10, 20, batch_first=True, bidirectional=True)

def forward(self, x):
return self.lstm(x)[0]

model = Model()
model.eval()

model_path = get_model_path()
converter = TFLiteConverter(
model, dummy_input, model_path, nchw_transpose=False, quantize_target_type='int8', map_bilstm_to_lstm=True
)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output, check_stride=False, atol=256.0, rtol=256.0)

@unittest.skipIf(not hasattr(torch.nn.quantized.dynamic, 'LSTM'), 'Quantized lstm is not supported')
def test_bilstm_dynamic_multi_layer_as_lstm(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.lstm = torch.nn.quantized.dynamic.LSTM(10, 20, 2, bidirectional=True)

def forward(self, x):
return self.lstm(x)[0]

model = Model()
model.eval()

model_path = get_model_path()
converter = TFLiteConverter(
model, dummy_input, model_path, nchw_transpose=False, quantize_target_type='int8', map_bilstm_to_lstm=True
)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output, atol=256.0, rtol=256.0)


if __name__ == '__main__':
unittest.main()
11 changes: 10 additions & 1 deletion tinynn/converter/base.py
@@ -43,6 +43,7 @@ def __init__(
group_conv_rewrite: bool = False,
rewrite_quantizable: bool = False,
tflite_micro_rewrite: bool = False,
map_bilstm_to_lstm: bool = False,
) -> None:
""" The TFLiteConverter class
@@ -76,6 +77,8 @@ def __init__(
rewrite_quantizable (bool): Rewriting quantizable ops (e.g. BATCH_MATMUL, SOFTMAX, LOG_SOFTMAX) \
to use quantized kernels. Defaults to False
tflite_micro_rewrite (bool): Rewriting for running on TFLite-micro. Defaults to False
map_bilstm_to_lstm (bool): Translating bidirectional LSTMs into TFLite graphs built from `UnidirectionalLSTM`, \
reverse and concat ops. Defaults to False
"""

self.model = model
@@ -115,6 +118,7 @@ def __init__(
self.group_conv_rewrite = group_conv_rewrite
self.rewrite_quantizable = rewrite_quantizable
self.tflite_micro_rewrite = tflite_micro_rewrite
self.map_bilstm_to_lstm = map_bilstm_to_lstm

if quantize_target_type == 'uint8':
self.q_type = np.uint8
@@ -333,7 +337,12 @@ def init_operations(self):

converter_type = OPERATOR_CONVERTER_DICT.get(k, NoTrackOperator)
converter = converter_type(
node, self.tensor_map, not self.strict_symmetric_check, self.q_type, self.hybrid_q_type
node,
self.tensor_map,
not self.strict_symmetric_check,
self.q_type,
self.hybrid_q_type,
self.map_bilstm_to_lstm,
)
# Don't track the operator if all the input nodes are not tracked unless it has custom implementation
# (e.g. prim::* ops)
69 changes: 61 additions & 8 deletions tinynn/converter/operators/torch/aten.py
@@ -143,15 +143,68 @@ def parse_common(
outputs = [layer_output]

if bidirectional:
ops.append(
tfl.BidirectionalSequenceLstmOperator(
inputs,
outputs,
fusedActivationFunction=tfl_schema.ActivationFunctionType.TANH,
timeMajor=not batch_first,
mergeOutputs=True,
if not self.map_bilstm_to_lstm:
ops.append(
tfl.BidirectionalSequenceLstmOperator(
inputs,
outputs,
fusedActivationFunction=tfl_schema.ActivationFunctionType.TANH,
timeMajor=not batch_first,
mergeOutputs=True,
)
)
)
else:
fw_i_end = input_start_indices[-1]
fw_s_start = state_start_index
fw_s_end = state_start_index + len(state_kinds)
fw_pad = num_input_tensors // 2 - fw_s_end
fw_lstm_inputs = (
inputs[:fw_i_end] + inputs[fw_s_start:fw_s_end] + [tfl.OptionalTensorInstance] * fw_pad
)
fw_out, bw_out = [
self.create_transform_tensor(t, quantization=outputs[0].quantization)
for t in np.split(outputs[0].tensor, 2, -1)
]

ops.append(
tfl.UnidirectionalSequenceLstmOperator(
fw_lstm_inputs,
[fw_out],
fusedActivationFunction=tfl_schema.ActivationFunctionType.TANH,
timeMajor=not batch_first,
)
)

time_dim = 1 if batch_first else 0
bw_in = self.create_transform_tensor(np.flip(current_input.tensor, time_dim))
bw_dim = self.create_attr_tensor(np.array([time_dim], dtype='int32'))
ops.append(tfl.ReverseV2Operator([current_input, bw_dim], [bw_in]))

bw_raw_out = self.create_transform_tensor(np.flip(bw_out.tensor, time_dim))
bw_o_start = input_start_indices[-1]
bw_o_end = state_start_index
bw_s_start = state_start_index + len(state_kinds)
bw_s_end = state_start_index + len(state_kinds) * num_directions
bw_pad = num_input_tensors // 2 - bw_s_end
bw_lstm_inputs = (
[bw_in]
+ inputs[bw_o_start:bw_o_end]
+ inputs[bw_s_start:bw_s_end]
+ [tfl.OptionalTensorInstance] * bw_pad
)

ops.append(
tfl.UnidirectionalSequenceLstmOperator(
bw_lstm_inputs,
[bw_raw_out],
fusedActivationFunction=tfl_schema.ActivationFunctionType.TANH,
timeMajor=not batch_first,
)
)

ops.append(tfl.ReverseV2Operator([bw_raw_out, bw_dim], [bw_out]))

ops.append(tfl.ConcatenationOperator([fw_out, bw_out], outputs, axis=2))
else:
ops.append(
tfl.UnidirectionalSequenceLstmOperator(
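
The branch added above decomposes a bidirectional LSTM into a forward `UnidirectionalSequenceLstm`, a backward `UnidirectionalSequenceLstm` fed with a time-reversed copy of the input (its output reversed back via `ReverseV2`), and a final `Concatenation` along the feature axis. The snippet below is an illustrative PyTorch-only sanity check of that decomposition, not part of the commit; the layer sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(9, 1, 10)  # time-major: (seq_len, batch, input_size)

bilstm = nn.LSTM(10, 20, bidirectional=True)
fw = nn.LSTM(10, 20)  # forward direction
bw = nn.LSTM(10, 20)  # backward direction

# Reuse the bidirectional weights so both formulations compute the same function.
with torch.no_grad():
    for name in ('weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0'):
        getattr(fw, name).copy_(getattr(bilstm, name))
        getattr(bw, name).copy_(getattr(bilstm, name + '_reverse'))

ref, _ = bilstm(x)                       # (9, 1, 40), both directions merged
fw_out, _ = fw(x)                        # forward pass over the original sequence
bw_out, _ = bw(x.flip(0))                # backward pass over the reversed sequence
rebuilt = torch.cat([fw_out, bw_out.flip(0)], dim=-1)  # reverse back, then concat

print(torch.allclose(ref, rebuilt, atol=1e-6))  # True
```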
5 changes: 4 additions & 1 deletion tinynn/converter/operators/torch/base.py
@@ -18,7 +18,9 @@


class OperatorConverter(ABC):
def __init__(self, node, tensor_map, asymmetric=True, q_type=np.uint8, hybrid_q_type=np.int8) -> None:
def __init__(
self, node, tensor_map, asymmetric=True, q_type=np.uint8, hybrid_q_type=np.int8, map_bilstm_to_lstm=False
) -> None:
self.input_names = self.get_input_names(node)
self.output_names = self.get_output_names(node)
self.input_tensors = self.get_input_tensors(tensor_map)
@@ -30,6 +32,7 @@ def __init__(self, node, tensor_map, asymmetric=True, q_type=np.uint8, hybrid_q_
self.asymmetric = asymmetric
self.q_type = q_type
self.hybrid_q_type = hybrid_q_type
self.map_bilstm_to_lstm = map_bilstm_to_lstm

@abstractmethod
def parse(self, node, attrs, args, graph_converter):
