Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Init performance by avoiding recursion analysis. #1500

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions domain_tests/arbitrary_domains_protobuf_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ TEST(ProtocolBuffer, RepeatedMutationEventuallyMutatesExtensionFields) {
},
Gt(0));
EXPECT_THAT(
GenerateNonUniqueValues(Arbitrary<TestProtobufWithExtension>(), 1, 5000),
GenerateNonUniqueValues(Arbitrary<TestProtobufWithExtension>(), 10, 100),
AllOf(Contains(has_ext), Contains(has_rep_ext)));
}

Expand Down Expand Up @@ -641,25 +641,28 @@ TEST(ProtocolBuffer, CountNumberOfFieldsCorrect) {
T v;
auto corpus_v_uninitialized = domain.FromValue(v);
EXPECT_TRUE(corpus_v_uninitialized != std::nullopt);
EXPECT_EQ(domain.CountNumberOfFields(corpus_v_uninitialized.value()), 26);
int fields_count = 28;
EXPECT_EQ(domain.CountNumberOfFields(corpus_v_uninitialized.value()),
fields_count + /*estimated subfields count*/ 4);
v.set_allocated_subproto(new SubT());
auto corpus_v_initizalize_one_optional_proto = domain.FromValue(v);
EXPECT_TRUE(corpus_v_initizalize_one_optional_proto != std::nullopt);
EXPECT_EQ(domain.CountNumberOfFields(
corpus_v_initizalize_one_optional_proto.value()),
28);
EXPECT_EQ(
domain.CountNumberOfFields(
corpus_v_initizalize_one_optional_proto.value()),
fields_count + /*subfields count*/ 2 + /*estimated subfields count*/ 2);
v.add_rep_subproto();
auto corpus_v_initizalize_one_repeated_proto_1 = domain.FromValue(v);
EXPECT_TRUE(corpus_v_initizalize_one_repeated_proto_1 != std::nullopt);
EXPECT_EQ(domain.CountNumberOfFields(
corpus_v_initizalize_one_repeated_proto_1.value()),
30);
fields_count + /*subfields count*/ 4);
v.add_rep_subproto();
auto corpus_v_initizalize_one_repeated_proto_2 = domain.FromValue(v);
EXPECT_TRUE(corpus_v_initizalize_one_repeated_proto_2 != std::nullopt);
EXPECT_EQ(domain.CountNumberOfFields(
corpus_v_initizalize_one_repeated_proto_2.value()),
32);
fields_count + /*subfields count*/ 6);
}

auto FieldNameHasSubstr(absl::string_view field_name) {
Expand Down
102 changes: 84 additions & 18 deletions fuzztest/internal/domains/protobuf_domain_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,10 @@ class ProtoPolicy {
return max;
}

void IncrementDepth() { ++depth_; }
void ResetDepth() { depth_ = 0; }
int depth() const { return depth_; }

private:
template <typename T>
struct FilterToValue {
Expand Down Expand Up @@ -345,6 +349,7 @@ class ProtoPolicy {
std::vector<FilterToValue<OptionalPolicy>> optional_policies_;
std::vector<FilterToValue<int64_t>> min_repeated_fields_sizes_;
std::vector<FilterToValue<int64_t>> max_repeated_fields_sizes_;
int depth_ = 0;

#define FUZZTEST_INTERNAL_POLICY_MEMBERS(Camel, cpp) \
private: \
Expand Down Expand Up @@ -453,6 +458,15 @@ class ProtobufDomainUntypedImpl
unset_oneof_fields_ = other.unset_oneof_fields_;
}

// We initialize to at most depth 5. This gives a good balanced between
// initialization speed and coverage.
bool IsTooDeep(absl::BitGenRef prng) {
constexpr int kMaxInitDepth = 5;
int acceptable_depth = absl::Uniform(absl::IntervalClosedClosed, prng,
int64_t{1}, kMaxInitDepth);
return policy_.depth() >= acceptable_depth;
}

corpus_type Init(absl::BitGenRef prng) {
if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed;
FUZZTEST_INTERNAL_CHECK(
Expand All @@ -466,23 +480,21 @@ class ProtobufDomainUntypedImpl
for (const FieldDescriptor* field : GetProtobufFields(descriptor)) {
if (auto* oneof = field->containing_oneof()) {
if (!oneof_to_field.contains(oneof->index())) {
oneof_to_field[oneof->index()] = SelectAFieldIndexInOneof(
oneof, prng,
/*non_recursive_only=*/customized_fields_.empty());
oneof_to_field[oneof->index()] =
SelectAFieldIndexInOneof(oneof, prng);
}
if (oneof_to_field[oneof->index()] != field->index()) continue;
} else if (!IsRequired(field) && customized_fields_.empty() &&
IsFieldRecursive(field)) {
// We avoid initializing non-required recursive fields by default (if
// they are not explicitly customized). Otherwise, the initialization
// may never terminate. If a proto has only non-required recursive
// fields, the initialization will be deterministic, which violates the
// assumption on domain Init. However, such cases should be extremely
// rare and breaking the assumption would not have severe consequences.
continue;
} else if (!IsRequired(field) && customized_fields_.empty()) {
// We avoid initializing non-required fields after some depth if they
// are not explicitly customized). Otherwise, the initialization may
// never terminate.
if (field->message_type() != nullptr && IsTooDeep(prng)) continue;
}
VisitProtobufField(field, InitializeVisitor{prng, *this, val});
}
// Depth is used only for initialization. When mutating, we should start
// from the beginning.
policy_.ResetDepth();
return val;
}

Expand Down Expand Up @@ -598,7 +610,10 @@ class ProtobufDomainUntypedImpl

if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
auto val_it = val.find(field->number());
if (val_it == val.end()) continue;
if (val_it == val.end()) {
total_weight += EstimatedMutationWeight(field);
continue;
}
if (field->is_repeated()) {
total_weight +=
GetSubDomain<ProtoMessageTag, true>(field).CountNumberOfFields(
Expand Down Expand Up @@ -635,7 +650,13 @@ class ProtobufDomainUntypedImpl

if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
auto val_it = val.find(field->number());
if (val_it == val.end()) continue;
if (val_it == val.end()) {
field_counter += EstimatedMutationWeight(field);
if (field_counter >= selected_field_index) continue;
VisitProtobufField(
field, MutateVisitor{prng, metadata, only_shrink, *this, val});
return field_counter;
}
if (field->is_repeated()) {
field_counter +=
GetSubDomain<ProtoMessageTag, true>(field).MutateSelectedField(
Expand Down Expand Up @@ -933,12 +954,13 @@ class ProtobufDomainUntypedImpl

template <typename OneofDescriptor>
int SelectAFieldIndexInOneof(const OneofDescriptor* oneof,
absl::BitGenRef prng, bool non_recursive_only) {
absl::BitGenRef prng) {
std::vector<int> fields;
for (int i = 0; i < oneof->field_count(); ++i) {
OptionalPolicy policy = GetOneofFieldPolicy(oneof->field(i));
if (policy == OptionalPolicy::kAlwaysNull) continue;
if (non_recursive_only && IsFieldRecursive(oneof->field(i))) continue;
// We avoid initializing non-required oneof fields after some depth.
if (policy == OptionalPolicy::kWithNull && IsTooDeep(prng)) continue;
fields.push_back(i);
}
if (fields.empty()) { // This can happen if all fields are unset.
Expand Down Expand Up @@ -1325,13 +1347,16 @@ class ProtobufDomainUntypedImpl
return descriptor->field_count() + extensions.size();
}

static auto GetProtobufFields(const Descriptor* descriptor) {
static auto GetProtobufFields(const Descriptor* descriptor,
bool include_extensions = true) {
std::vector<const FieldDescriptor*> fields;
fields.reserve(descriptor->field_count());
for (int i = 0; i < descriptor->field_count(); ++i) {
fields.push_back(descriptor->field(i));
}
descriptor->file()->pool()->FindAllExtensions(descriptor, &fields);
if (include_extensions) {
descriptor->file()->pool()->FindAllExtensions(descriptor, &fields);
}
return fields;
}

Expand Down Expand Up @@ -1599,6 +1624,7 @@ class ProtobufDomainUntypedImpl
field->message_type()),
use_lazy_initialization_);
result.SetPolicy(policy_);
result.GetPolicy().IncrementDepth();
return Domain<std::unique_ptr<Message>>(result);
} else {
return Domain<T>(ArbitraryImpl<T>());
Expand Down Expand Up @@ -1673,6 +1699,46 @@ class ProtobufDomainUntypedImpl
/*consider_non_terminating_recursions=*/false);
}

// Returns an estimate for the number of sub-fields of a given field to be
// used as a heuristic. For simplicity, the recursive sub-fields are ignored.
// And for efficiency, proto extensions as well as protos with large number of
// fields are ignored.
static int64_t EstimatedMutationWeight(const FieldDescriptor* field) {
auto descriptor = field->message_type();
if (!descriptor) return 1; // non-message fields
static absl::flat_hash_map<std::string, int> kFieldCountCache;
if (auto it = kFieldCountCache.find(descriptor->full_name());
it != kFieldCountCache.end()) {
return it->second;
}
absl::flat_hash_set<decltype(descriptor)> parents;
int result = EstimatedMutationWeight(descriptor, parents);
kFieldCountCache.insert({descriptor->full_name(), result});
return result;
}

template <typename Descriptor>
static int EstimatedMutationWeight(
const Descriptor* descriptor,
absl::flat_hash_set<const Descriptor*>& parents) {
int field_count = 0;
if (parents.contains(descriptor)) return field_count;
parents.insert(descriptor);
for (const FieldDescriptor* field :
GetProtobufFields(descriptor, /*include_extensions=*/false)) {
const auto* child = field->message_type();
field_count += child ? EstimatedMutationWeight(child, parents) : 1;
constexpr int kMaxFieldCount = 1000;
// Avoid expensive estimation operations.
if (field_count > kMaxFieldCount) {
field_count = kMaxFieldCount;
break;
}
}
parents.erase(descriptor);
return field_count;
}

bool IsOneofRecursive(const OneofDescriptor* oneof,
absl::flat_hash_set<const Descriptor*>& parents,
const ProtoPolicy<Message>& policy,
Expand Down
Loading