Skip to content

Commit

Permalink
Fixed s2t handling of proto3 optional field without presence semantics.
Browse files Browse the repository at this point in the history
Note: this feature is not available in OSS yet, because the proto library s2t build against does not support it.
PiperOrigin-RevId: 369769351
  • Loading branch information
brills authored and tfx-copybara committed Apr 22, 2021
1 parent 072a4c5 commit 77fd1a7
Show file tree
Hide file tree
Showing 19 changed files with 352 additions and 135 deletions.
1 change: 1 addition & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ sh_binary(
"//struct2tensor/test:test_any_py_pb2",
"//struct2tensor/test:test_extension_py_pb2",
"//struct2tensor/test:test_map_py_pb2",
"//struct2tensor/test:test_proto3_py_pb2",
"//struct2tensor/test:test_py_pb2",
],
)
2 changes: 2 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

## Bug Fixes and Other Changes

* Introduced DecodeProtoSparseV4. It is same as V3 and will replace V3 soon.

## Breaking Changes

## Deprecations
Expand Down
13 changes: 13 additions & 0 deletions struct2tensor/calculate_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,26 @@ class Options(object):
When a method takes an optional Options object but none is provided, it will
replace it with get_default_options() .
Available options:
ragged_checks: if True, add assertion ops when converting a Prensor object
to RaggedTensors.
sparse_checks: if True, add assertion ops when converting a Prensor object
to SparseTensors.
use_string_view: if True, decode sub-messages into string views to avoid
copying.
experimental_honor_proto3_optional_semantics: if True, if a proto3 primitive
optional field without the presence semantic (i.e. the field is without
the "optional" or "repeated" label) is requested to be parsed, it will
always have a value for each input parent message. If a value is not
present on wire, the default value (0 or "") will be used.
"""

def __init__(self, ragged_checks: bool, sparse_checks: bool):
"""Create options."""
self.ragged_checks = ragged_checks
self.sparse_checks = sparse_checks
self.use_string_view = False
self.experimental_honor_proto3_optional_semantics = False

def __str__(self):
return ("{ragged_checks:" + str(self.ragged_checks) + ", sparse_checks: " +
Expand Down
6 changes: 4 additions & 2 deletions struct2tensor/expression_impl/parse_message_level_ex.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def parse_message_level_ex(
desc: descriptor.Descriptor,
field_names: Set[ProtoFieldName],
message_format: str = "binary",
backing_str_tensor: Optional[tf.Tensor] = None
backing_str_tensor: Optional[tf.Tensor] = None,
honor_proto3_optional_semantics: bool = False
) -> Mapping[StrStep, struct2tensor_ops._ParsedField]:
"""Parses regular fields, extensions, any casts, and map protos."""
raw_field_names = _get_field_names_to_parse(desc, field_names)
Expand All @@ -106,7 +107,8 @@ def parse_message_level_ex(
desc,
raw_field_names,
message_format=message_format,
backing_str_tensor=backing_str_tensor))
backing_str_tensor=backing_str_tensor,
honor_proto3_optional_semantics=honor_proto3_optional_semantics))
regular_field_map = {x.field_name: x for x in regular_fields}

any_fields = _get_any_parsed_fields(desc, regular_field_map, field_names)
Expand Down
59 changes: 33 additions & 26 deletions struct2tensor/expression_impl/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from struct2tensor.ops import struct2tensor_ops
import tensorflow as tf

from google.protobuf.descriptor_pb2 import FileDescriptorSet
from google.protobuf import descriptor_pb2
from google.protobuf import descriptor
from google.protobuf.descriptor_pool import DescriptorPool

Expand All @@ -52,7 +52,7 @@ def is_proto_expression(expr: expression.Expression) -> bool:
def create_expression_from_file_descriptor_set(
tensor_of_protos: tf.Tensor,
proto_name: ProtoFullName,
file_descriptor_set: FileDescriptorSet,
file_descriptor_set: descriptor_pb2.FileDescriptorSet,
message_format: str = "binary") -> expression.Expression:
"""Create an expression from a 1D tensor of serialized protos.
Expand Down Expand Up @@ -277,21 +277,21 @@ def calculate(
raise ValueError("Cannot find {} in {}".format(
str(self), str(parent_value)))
return self.calculate_from_parsed_field(parsed_field, destinations,
options.use_string_view)
options)
raise ValueError("Not a _ParentProtoNodeTensor: " + str(type(parent_value)))

@abc.abstractmethod
def calculate_from_parsed_field(self,
parsed_field: struct2tensor_ops._ParsedField,
destinations: Sequence[expression.Expression],
use_string_view: bool) -> prensor.NodeTensor:
def calculate_from_parsed_field(
self, parsed_field: struct2tensor_ops._ParsedField, # pylint: disable=protected-access
destinations: Sequence[expression.Expression],
options: calculate_options.Options) -> prensor.NodeTensor:
"""Calculate the NodeTensor given the parsed fields requested from a parent.
Args:
parsed_field: the parsed field from name_as_field.
destinations: the destination of the expression.
use_string_view: if true, enables string_views to be used for intermediate
serialized proto outputs.
options: calculate options.
Returns:
A node tensor for this node.
"""
Expand Down Expand Up @@ -320,10 +320,10 @@ def __init__(self, parent: "_ParentProtoExpression",
# TODO(martinz): make _get_dtype_from_cpp_type public.
self._field_descriptor = desc

def calculate_from_parsed_field(self,
parsed_field: struct2tensor_ops._ParsedField,
destinations: Sequence[expression.Expression],
use_string_view: bool) -> prensor.NodeTensor:
def calculate_from_parsed_field(
self, parsed_field: struct2tensor_ops._ParsedField, # pylint: disable=protected-access
destinations: Sequence[expression.Expression],
options: calculate_options.Options) -> prensor.NodeTensor:
return prensor.LeafNodeTensor(parsed_field.index, parsed_field.value,
self.is_repeated)

Expand Down Expand Up @@ -377,19 +377,21 @@ def __init__(self, parent: "_ParentProtoExpression",
backing_str_tensor)
self._desc = desc

def calculate_from_parsed_field(self,
parsed_field: struct2tensor_ops._ParsedField,
destinations: Sequence[expression.Expression],
use_string_view: bool) -> prensor.NodeTensor:
def calculate_from_parsed_field(
self, parsed_field: struct2tensor_ops._ParsedField, # pylint:disable=protected-access
destinations: Sequence[expression.Expression],
options: calculate_options.Options) -> prensor.NodeTensor:
needed_fields = _get_needed_fields(destinations)
backing_str_tensor = None
if use_string_view:
if options.use_string_view:
backing_str_tensor = self._backing_str_tensor
fields = parse_message_level_ex.parse_message_level_ex(
parsed_field.value,
self._desc,
needed_fields,
backing_str_tensor=backing_str_tensor)
backing_str_tensor=backing_str_tensor,
honor_proto3_optional_semantics=options
.experimental_honor_proto3_optional_semantics)
return _ProtoChildNodeTensor(parsed_field.index, self.is_repeated, fields)

def calculation_equal(self, expr: expression.Expression) -> bool:
Expand Down Expand Up @@ -429,21 +431,24 @@ def __init__(self, parent: "_ParentProtoExpression",
def transform_fn(self):
return self._transform_fn

def calculate_from_parsed_field(self,
parsed_field: struct2tensor_ops._ParsedField,
destinations: Sequence[expression.Expression],
use_string_view: bool) -> prensor.NodeTensor:
def calculate_from_parsed_field(
self,
parsed_field: struct2tensor_ops._ParsedField, # pylint:disable=protected-access
destinations: Sequence[expression.Expression],
options: calculate_options.Options) -> prensor.NodeTensor:
needed_fields = _get_needed_fields(destinations)
transformed_parent_indices, transformed_values = self._transform_fn(
parsed_field.index, parsed_field.value)
backing_str_tensor = None
if use_string_view:
if options.use_string_view:
backing_str_tensor = self._backing_str_tensor
fields = parse_message_level_ex.parse_message_level_ex(
transformed_values,
self._desc,
needed_fields,
backing_str_tensor=backing_str_tensor)
backing_str_tensor=backing_str_tensor,
honor_proto3_optional_semantics=options
.experimental_honor_proto3_optional_semantics)
return _ProtoChildNodeTensor(transformed_parent_indices, self.is_repeated,
fields)

Expand Down Expand Up @@ -517,7 +522,9 @@ def calculate(
self._descriptor,
needed_fields,
message_format=self._message_format,
backing_str_tensor=backing_str_tensor)
backing_str_tensor=backing_str_tensor,
honor_proto3_optional_semantics=options
.experimental_honor_proto3_optional_semantics)
return _ProtoRootNodeTensor(size, fields)

def calculation_is_identity(self) -> bool:
Expand Down
2 changes: 2 additions & 0 deletions struct2tensor/expression_impl/proto_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from struct2tensor.test import test_extension_pb2
from struct2tensor.test import test_map_pb2
from struct2tensor.test import test_pb2
from struct2tensor.test import test_proto3_pb2
import tensorflow as tf

from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import
Expand Down Expand Up @@ -456,6 +457,7 @@ def test_transformed_field_values_with_multiple_transforms(
self._check_string_view()



def _reverse_values(parent_indices, values):
"""A simple function for testing create_transformed_field."""
return parent_indices, tf.reverse(values, axis=[-1])
Expand Down
Loading

0 comments on commit 77fd1a7

Please sign in to comment.