Skip to content

Commit

Permalink
Project import generated by Copybara.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 284601545
  • Loading branch information
tfx-copybara authored and mzinkevi committed Dec 10, 2019
1 parent 46c4f48 commit 940ac11
Show file tree
Hide file tree
Showing 22 changed files with 499 additions and 261 deletions.
234 changes: 149 additions & 85 deletions examples/prensor_playground.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions struct2tensor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
# simply and efficiently.

package(default_visibility = [
"//analysis/dremel/core/tabledefs/tensorflow:__subpackages__",
"//learning/tfx/autotfx:__subpackages__",
"//learning/tfx/users/tfx/util/tf_ranking:__subpackages__",
"//nlp/nlx/infrastructure/multiscale:__subpackages__",
"//third_party/py/tfx_bsl:__subpackages__",
"//third_party/tensorflow_ranking/google:__subpackages__",
"//video/youtube/discovery/tensorflow/python/input/feature_parser:__subpackages__",
"@//:__subpackages__",
])
Expand Down
27 changes: 25 additions & 2 deletions struct2tensor/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,20 @@ def apply(self,
def apply_schema(self, schema):
return apply_schema.apply_schema(self, schema)

def get_paths_with_schema(self):
"""Extract only paths that contain schema information."""
result = []
for name, child in self.get_known_children().items():
if child.schema_feature is None:
continue
result.extend(
[path.Path([name]).concat(x) for x in child.get_paths_with_schema()])
# Note: We always take the root path and so will return an empty schema
# if there is no schema information on any nodes, including the root.
if not result:
result.append(path.Path([]))
return result

def _populate_schema_feature_children(self, feature_list):
"""Populate a feature list from the children of this node.
Expand All @@ -535,8 +549,17 @@ def _populate_schema_feature_children(self, feature_list):
new_feature.struct_domain.feature)
new_feature.name = name

def get_schema(self):
"""Returns a schema for the entire tree."""
def get_schema(self, create_schema_features=True):
"""Returns a schema for the entire tree.
Args:
create_schema_features: If True, schema features are added for all
children and a schema entry is created if not available on the child. If
False, features are left off of the returned schema if there is no
schema_feature on the child.
"""
if not create_schema_features:
return self.project(self.get_paths_with_schema()).get_schema()
result = schema_pb2.Schema()
self._populate_schema_feature_children(result.feature)
return result
Expand Down
19 changes: 12 additions & 7 deletions struct2tensor/expression_impl/parse_message_level_ex.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
from struct2tensor import path
from struct2tensor.ops import struct2tensor_ops
import tensorflow as tf
from typing import List, Mapping, Optional, Sequence, Set
from typing import List, Mapping, Optional, Sequence, Set, Text

from google.protobuf import descriptor

Expand All @@ -89,15 +89,20 @@
StrStep = str # pylint: disable=g-ambiguous-str-annotation


def parse_message_level_ex(tensor_of_protos,
desc,
field_names
):
def parse_message_level_ex(
tensor_of_protos,
desc,
field_names,
message_format = "binary"
):
"""Parses regular fields, extensions, any casts, and map protos."""
raw_field_names = _get_field_names_to_parse(desc, field_names)
regular_fields = list(
struct2tensor_ops.parse_message_level(tensor_of_protos, desc,
raw_field_names))
struct2tensor_ops.parse_message_level(
tensor_of_protos,
desc,
raw_field_names,
message_format=message_format))
regular_field_map = {x.field_name: x for x in regular_fields}

any_fields = _get_any_parsed_fields(desc, regular_field_map, field_names)
Expand Down
31 changes: 23 additions & 8 deletions struct2tensor/expression_impl/parse_message_level_ex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@
from __future__ import print_function

from absl.testing import absltest
from absl.testing import parameterized
from struct2tensor.expression_impl import parse_message_level_ex
from struct2tensor.test import test_any_pb2
from struct2tensor.test import test_map_pb2
from struct2tensor.test import test_pb2
import tensorflow as tf

from google.protobuf import text_format
from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import

_INDEX = "index"
Expand All @@ -39,11 +41,18 @@
_USERINFO_NO_PARENS = b"type.googleapis.com/struct2tensor.test.UserInfo"
_USERINFO = "(type.googleapis.com/struct2tensor.test.UserInfo)"

_MESSAGE_FORMATS = ["binary", "text"]

def _run_parse_message_level_ex(proto_list, fields):
serialized = [x.SerializeToString() for x in proto_list]

def _run_parse_message_level_ex(proto_list, fields, message_format="binary"):
if message_format == "text":
serialized = [text_format.MessageToString(x) for x in proto_list]
elif message_format == "binary":
serialized = [x.SerializeToString() for x in proto_list]
else:
raise ValueError('Message format must be one of "text", "binary"')
parsed_field_dict = parse_message_level_ex.parse_message_level_ex(
tf.constant(serialized), proto_list[0].DESCRIPTOR, fields)
tf.constant(serialized), proto_list[0].DESCRIPTOR, fields, message_format)
sess_input = {}
for key, value in parsed_field_dict.items():
local_dict = {}
Expand Down Expand Up @@ -82,8 +91,10 @@ def _get_empty_all_simple():


@test_util.run_all_in_graph_and_eager_modes
class ParseMessageLevelExTest(tf.test.TestCase):
class ParseMessageLevelExTest(parameterized.TestCase, tf.test.TestCase):

# TODO(askerryryan): Consider supporting Any types for text format. Currently
# only binary format is supported.
def test_any_field(self):
original_protos = _create_any_protos()
result = _run_parse_message_level_ex(original_protos, {_ALLSIMPLE})
Expand Down Expand Up @@ -138,7 +149,9 @@ def test_full_name_from_any_step(self):
self.assertIsNone(
parse_message_level_ex.get_full_name_from_any_step("broken)"))

def test_normal_field(self):
@parameterized.named_parameters(
[dict(testcase_name=f, message_format=f) for f in _MESSAGE_FORMATS])
def test_normal_field(self, message_format):
"""Test three messages with a repeated string."""
all_simple = test_pb2.AllSimple()
all_simple.repeated_string.append("foo")
Expand All @@ -147,21 +160,23 @@ def test_normal_field(self):

result = _run_parse_message_level_ex(
[all_simple, all_simple_empty, all_simple, all_simple],
{"repeated_string"})
{"repeated_string"}, message_format)
self.assertNotIn("repeated_bool", result)

self.assertAllEqual(result["repeated_string"][_INDEX], [0, 0, 2, 2, 3, 3])
self.assertAllEqual(result["repeated_string"][_VALUE],
[b"foo", b"foo2", b"foo", b"foo2", b"foo", b"foo2"])

def test_bool_key_type(self):
@parameterized.named_parameters(
[dict(testcase_name=f, message_format=f) for f in _MESSAGE_FORMATS])
def test_bool_key_type(self, message_format):
map_field = "bool_string_map[1]"
message_with_map_0 = test_map_pb2.MessageWithMap()
message_with_map_0.bool_string_map[False] = "hello"
message_with_map_1 = test_map_pb2.MessageWithMap()
message_with_map_1.bool_string_map[True] = "goodbye"
result = _run_parse_message_level_ex(
[message_with_map_0, message_with_map_1], {map_field})
[message_with_map_0, message_with_map_1], {map_field}, message_format)
self.assertIn(map_field, result)
self.assertAllEqual(result[map_field][_VALUE], [b"goodbye"])
self.assertAllEqual(result[map_field][_INDEX], [1])
Expand Down
3 changes: 2 additions & 1 deletion struct2tensor/expression_impl/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ class _ProjectExpression(expression.Expression):
"""Project all subfields of an expression."""

def __init__(self, origin, paths):
super(_ProjectExpression, self).__init__(origin.is_repeated, origin.type)
super(_ProjectExpression, self).__init__(origin.is_repeated, origin.type,
origin.schema_feature)
self._paths_map = _group_paths_by_first_step(paths)
self._origin = origin

Expand Down
32 changes: 24 additions & 8 deletions struct2tensor/expression_impl/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from struct2tensor.expression_impl import parse_message_level_ex
from struct2tensor.ops import struct2tensor_ops
import tensorflow as tf
from typing import FrozenSet, Mapping, Optional, Sequence, Set, Tuple, Union
from typing import FrozenSet, Mapping, Optional, Sequence, Set, Text, Tuple, Union


from google.protobuf.descriptor_pb2 import FileDescriptorSet
Expand All @@ -55,8 +55,10 @@ def is_proto_expression(expr):


def create_expression_from_file_descriptor_set(
tensor_of_protos, proto_name,
file_descriptor_set):
tensor_of_protos,
proto_name,
file_descriptor_set,
message_format = "binary"):
"""Create an expression from a 1D tensor of serialized protos.
Args:
Expand All @@ -67,6 +69,8 @@ def create_expression_from_file_descriptor_set(
and all its dependencies' FileDescriptorProto. Note that if file1 imports
file2, then file2's FileDescriptorProto must precede file1's in
file_descriptor_set.file.
message_format: Indicates the format of the protocol buffer: is one of
'text' or 'binary'.
Returns:
An expression.
Expand All @@ -80,22 +84,25 @@ def create_expression_from_file_descriptor_set(
# This method raises if proto not found.
desc = pool.FindMessageTypeByName(proto_name)

return create_expression_from_proto(tensor_of_protos, desc)
return create_expression_from_proto(tensor_of_protos, desc, message_format)


def create_expression_from_proto(
tensor_of_protos,
desc):
desc,
message_format = "binary"):
"""Create an expression from a 1D tensor of serialized protos.
Args:
tensor_of_protos: 1D tensor of serialized protos.
desc: a descriptor of protos in tensor of protos.
message_format: Indicates the format of the protocol buffer: is one of
'text' or 'binary'.
Returns:
An expression.
"""
return _ProtoRootExpression(desc, tensor_of_protos)
return _ProtoRootExpression(desc, tensor_of_protos, message_format)


class _ProtoRootNodeTensor(prensor.RootNodeTensor):
Expand Down Expand Up @@ -309,16 +316,22 @@ class _ProtoRootExpression(expression.Expression):
_ProtoChildExpression and _ProtoLeafExpression to consume.
"""

def __init__(self, desc, tensor_of_protos):
def __init__(self,
desc,
tensor_of_protos,
message_format = "binary"):
"""Initialize a proto expression.
Args:
desc: the descriptor of the expression.
tensor_of_protos: a 1-D tensor to get the protos from.
message_format: Indicates the format of the protocol buffer: is one of
'text' or 'binary'.
"""
super(_ProtoRootExpression, self).__init__(True, None)
self._descriptor = desc
self._tensor_of_protos = tensor_of_protos
self._message_format = message_format

def get_path(self):
"""Returns the path to the root of the proto."""
Expand All @@ -342,7 +355,10 @@ def calculate(
size = tf.size(self._tensor_of_protos, out_type=tf.int64)
needed_fields = _get_needed_fields(destinations)
fields = parse_message_level_ex.parse_message_level_ex(
self._tensor_of_protos, self._descriptor, needed_fields)
self._tensor_of_protos,
self._descriptor,
needed_fields,
message_format=self._message_format)
return _ProtoRootNodeTensor(size, fields)

def calculation_is_identity(self):
Expand Down
60 changes: 60 additions & 0 deletions struct2tensor/expression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,66 @@ def test_get_schema(self):
self.assertEqual(doc_feature_map["bar"].presence.min_count, 17)
self.assertIn("keep_me", doc_feature_map)

def test_get_schema_missing_features(self):
# The expr has a number of features: foo, foorepeated, doc, user.
expr = create_expression.create_expression_from_prensor(
prensor_test_util.create_big_prensor())
# The schema has only a subset of the features on the expr.
schema = schema_pb2.Schema()
feature = schema.feature.add()
feature.name = "foo"
feature.type = schema_pb2.FeatureType.INT
feature.value_count.min = 1
feature.value_count.max = 1
feature = schema.feature.add()
feature.name = "foorepeated"
feature.type = schema_pb2.FeatureType.INT
feature.value_count.min = 0
feature.value_count.max = 5
feature = schema.feature.add()
feature.name = "doc"
feature.type = schema_pb2.FeatureType.STRUCT
feature.struct_domain.feature.append(
schema_pb2.Feature(name="keep_me", type=schema_pb2.FeatureType.INT))

# By default, the output schema has all features present in the expr.
expr = expr.apply_schema(schema)
output_schema = expr.get_schema()
self.assertNotEqual(schema, output_schema)
self.assertLen(schema.feature, 3)
self.assertLen(output_schema.feature, 4)

# With create_schema_features = False, only features on the original schema
# propogate to the new schema.
output_schema = expr.get_schema(create_schema_features=False)
self.assertLen(output_schema.feature, 3)

def test_get_schema_empty_schema(self):
expr = create_expression.create_expression_from_prensor(
prensor_test_util.create_big_prensor())
schema = schema_pb2.Schema()
# By default, the output schema has all features present in the proto.
expr = expr.apply_schema(schema)
output_schema = expr.get_schema()
self.assertEmpty(schema.feature)
self.assertLen(output_schema.feature, 4)

# With create_schema_features = False, the schema will be empty.
output_schema = expr.get_schema(create_schema_features=False)
self.assertEmpty(output_schema.feature)
self.assertEqual(schema, output_schema)

def test_get_schema_no_schema(self):
expr = create_expression.create_expression_from_prensor(
prensor_test_util.create_big_prensor())
output_schema = expr.get_schema()
self.assertLen(output_schema.feature, 4)

# With create_schema_features = False, the schema will be empty.
output_schema = expr.get_schema(create_schema_features=False)
self.assertEmpty(output_schema.feature)
self.assertEqual(schema_pb2.Schema(), output_schema)


@test_util.run_all_in_graph_and_eager_modes
class ExpressionValuesTest(tf.test.TestCase):
Expand Down
9 changes: 5 additions & 4 deletions struct2tensor/kernels/decode_proto_map_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ using ::tensorflow::Status;
using ::tensorflow::Tensor;
using ::tensorflow::TensorShape;
using ::tensorflow::TensorShapeUtils;
using ::tensorflow::tstring;
namespace errors = ::tensorflow::errors;

constexpr int kKeyFieldNumber = 1;
Expand Down Expand Up @@ -124,7 +125,7 @@ struct FieldTypeTraits {};
using TensorCppType = \
typename tensorflow::EnumToDataType<kTFDataType>::Type; \
static_assert(sizeof(TensorCppType) == sizeof(FieldCppType) || \
(std::is_same<TensorCppType, std::string>::value && \
(std::is_same<TensorCppType, tstring>::value && \
std::is_same<FieldCppType, absl::string_view>::value), \
"Unexpected FIELD_CPP_TYPE and TENSOR_DTYPE_ENUM pair"); \
};
Expand Down Expand Up @@ -392,13 +393,13 @@ class MapEntryCollector {
~MapEntryCollector() {}

Status ConsumeAndPopulateOutputTensors(
absl::Span<const std::string> serialized_protos,
absl::Span<const tstring> serialized_protos,
absl::Span<const tensorflow::int64> parent_indices,
OpKernelContext* op_kernel_contxt) const {
std::unique_ptr<ValueCollectorBase> value_collector;
TF_RETURN_IF_ERROR(MakeValueCollector(num_keys_, &value_collector));
for (int i = 0; i < serialized_protos.size(); ++i) {
const std::string& p = serialized_protos[i];
const tstring& p = serialized_protos[i];
const int64_t parent_index = parent_indices[i];
StreamingProtoReader reader(p);
bool key_field_found = false;
Expand Down Expand Up @@ -655,7 +656,7 @@ class DecodeProtoMapOp : public OpKernel {
OP_REQUIRES_OK(
context,
map_entry_collector_->ConsumeAndPopulateOutputTensors(
absl::MakeConstSpan(serialized_protos_tensor.flat<std::string>().data(),
absl::MakeConstSpan(serialized_protos_tensor.flat<tstring>().data(),
num_protos),
absl::MakeConstSpan(
parent_indices_tensor.flat<tensorflow::int64>().data(),
Expand Down
Loading

0 comments on commit 940ac11

Please sign in to comment.