diff --git a/domain_tests/arbitrary_domains_protobuf_test.cc b/domain_tests/arbitrary_domains_protobuf_test.cc index 8ec6150e..d44d1e46 100644 --- a/domain_tests/arbitrary_domains_protobuf_test.cc +++ b/domain_tests/arbitrary_domains_protobuf_test.cc @@ -133,7 +133,7 @@ TEST(ProtocolBuffer, RepeatedMutationEventuallyMutatesExtensionFields) { }, Gt(0)); EXPECT_THAT( - GenerateNonUniqueValues(Arbitrary(), 1, 5000), + GenerateNonUniqueValues(Arbitrary(), 10, 100), AllOf(Contains(has_ext), Contains(has_rep_ext))); } @@ -641,25 +641,28 @@ TEST(ProtocolBuffer, CountNumberOfFieldsCorrect) { T v; auto corpus_v_uninitialized = domain.FromValue(v); EXPECT_TRUE(corpus_v_uninitialized != std::nullopt); - EXPECT_EQ(domain.CountNumberOfFields(corpus_v_uninitialized.value()), 26); + int fields_count = 28; + EXPECT_EQ(domain.CountNumberOfFields(corpus_v_uninitialized.value()), + fields_count + /*estimated subfields count*/ 4); v.set_allocated_subproto(new SubT()); auto corpus_v_initizalize_one_optional_proto = domain.FromValue(v); EXPECT_TRUE(corpus_v_initizalize_one_optional_proto != std::nullopt); - EXPECT_EQ(domain.CountNumberOfFields( - corpus_v_initizalize_one_optional_proto.value()), - 28); + EXPECT_EQ( + domain.CountNumberOfFields( + corpus_v_initizalize_one_optional_proto.value()), + fields_count + /*subfields count*/ 2 + /*estimated subfields count*/ 2); v.add_rep_subproto(); auto corpus_v_initizalize_one_repeated_proto_1 = domain.FromValue(v); EXPECT_TRUE(corpus_v_initizalize_one_repeated_proto_1 != std::nullopt); EXPECT_EQ(domain.CountNumberOfFields( corpus_v_initizalize_one_repeated_proto_1.value()), - 30); + fields_count + /*subfields count*/ 4); v.add_rep_subproto(); auto corpus_v_initizalize_one_repeated_proto_2 = domain.FromValue(v); EXPECT_TRUE(corpus_v_initizalize_one_repeated_proto_2 != std::nullopt); EXPECT_EQ(domain.CountNumberOfFields( corpus_v_initizalize_one_repeated_proto_2.value()), - 32); + fields_count + /*subfields count*/ 6); } auto FieldNameHasSubstr(absl::string_view field_name) { diff --git a/fuzztest/internal/domains/protobuf_domain_impl.h b/fuzztest/internal/domains/protobuf_domain_impl.h index 279bf2dd..b79bcbc4 100644 --- a/fuzztest/internal/domains/protobuf_domain_impl.h +++ b/fuzztest/internal/domains/protobuf_domain_impl.h @@ -311,6 +311,10 @@ class ProtoPolicy { return max; } + void IncrementDepth() { ++depth_; } + void ResetDepth() { depth_ = 0; } + int depth() const { return depth_; } + private: template struct FilterToValue { @@ -345,6 +349,7 @@ class ProtoPolicy { std::vector> optional_policies_; std::vector> min_repeated_fields_sizes_; std::vector> max_repeated_fields_sizes_; + int depth_ = 0; #define FUZZTEST_INTERNAL_POLICY_MEMBERS(Camel, cpp) \ private: \ @@ -453,6 +458,15 @@ class ProtobufDomainUntypedImpl unset_oneof_fields_ = other.unset_oneof_fields_; } + // We initialize to at most depth 5. This gives a good balanced between + // initialization speed and coverage. + bool IsTooDeep(absl::BitGenRef prng) { + constexpr int kMaxInitDepth = 5; + int acceptable_depth = absl::Uniform(absl::IntervalClosedClosed, prng, + int64_t{1}, kMaxInitDepth); + return policy_.depth() >= acceptable_depth; + } + corpus_type Init(absl::BitGenRef prng) { if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed; FUZZTEST_INTERNAL_CHECK( @@ -466,23 +480,21 @@ class ProtobufDomainUntypedImpl for (const FieldDescriptor* field : GetProtobufFields(descriptor)) { if (auto* oneof = field->containing_oneof()) { if (!oneof_to_field.contains(oneof->index())) { - oneof_to_field[oneof->index()] = SelectAFieldIndexInOneof( - oneof, prng, - /*non_recursive_only=*/customized_fields_.empty()); + oneof_to_field[oneof->index()] = + SelectAFieldIndexInOneof(oneof, prng); } if (oneof_to_field[oneof->index()] != field->index()) continue; - } else if (!IsRequired(field) && customized_fields_.empty() && - IsFieldRecursive(field)) { - // We avoid initializing non-required recursive fields by default (if - // they are not explicitly customized). Otherwise, the initialization - // may never terminate. If a proto has only non-required recursive - // fields, the initialization will be deterministic, which violates the - // assumption on domain Init. However, such cases should be extremely - // rare and breaking the assumption would not have severe consequences. - continue; + } else if (!IsRequired(field) && customized_fields_.empty()) { + // We avoid initializing non-required fields after some depth if they + // are not explicitly customized). Otherwise, the initialization may + // never terminate. + if (field->message_type() != nullptr && IsTooDeep(prng)) continue; } VisitProtobufField(field, InitializeVisitor{prng, *this, val}); } + // Depth is used only for initialization. When mutating, we should start + // from the beginning. + policy_.ResetDepth(); return val; } @@ -598,7 +610,10 @@ class ProtobufDomainUntypedImpl if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { auto val_it = val.find(field->number()); - if (val_it == val.end()) continue; + if (val_it == val.end()) { + total_weight += EstimatedMutationWeight(field); + continue; + } if (field->is_repeated()) { total_weight += GetSubDomain(field).CountNumberOfFields( @@ -635,7 +650,13 @@ class ProtobufDomainUntypedImpl if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { auto val_it = val.find(field->number()); - if (val_it == val.end()) continue; + if (val_it == val.end()) { + field_counter += EstimatedMutationWeight(field); + if (field_counter >= selected_field_index) continue; + VisitProtobufField( + field, MutateVisitor{prng, metadata, only_shrink, *this, val}); + return field_counter; + } if (field->is_repeated()) { field_counter += GetSubDomain(field).MutateSelectedField( @@ -933,12 +954,13 @@ class ProtobufDomainUntypedImpl template int SelectAFieldIndexInOneof(const OneofDescriptor* oneof, - absl::BitGenRef prng, bool non_recursive_only) { + absl::BitGenRef prng) { std::vector fields; for (int i = 0; i < oneof->field_count(); ++i) { OptionalPolicy policy = GetOneofFieldPolicy(oneof->field(i)); if (policy == OptionalPolicy::kAlwaysNull) continue; - if (non_recursive_only && IsFieldRecursive(oneof->field(i))) continue; + // We avoid initializing non-required oneof fields after some depth. + if (policy == OptionalPolicy::kWithNull && IsTooDeep(prng)) continue; fields.push_back(i); } if (fields.empty()) { // This can happen if all fields are unset. @@ -1325,13 +1347,16 @@ class ProtobufDomainUntypedImpl return descriptor->field_count() + extensions.size(); } - static auto GetProtobufFields(const Descriptor* descriptor) { + static auto GetProtobufFields(const Descriptor* descriptor, + bool include_extensions = true) { std::vector fields; fields.reserve(descriptor->field_count()); for (int i = 0; i < descriptor->field_count(); ++i) { fields.push_back(descriptor->field(i)); } - descriptor->file()->pool()->FindAllExtensions(descriptor, &fields); + if (include_extensions) { + descriptor->file()->pool()->FindAllExtensions(descriptor, &fields); + } return fields; } @@ -1599,6 +1624,7 @@ class ProtobufDomainUntypedImpl field->message_type()), use_lazy_initialization_); result.SetPolicy(policy_); + result.GetPolicy().IncrementDepth(); return Domain>(result); } else { return Domain(ArbitraryImpl()); @@ -1673,6 +1699,46 @@ class ProtobufDomainUntypedImpl /*consider_non_terminating_recursions=*/false); } + // Returns an estimate for the number of sub-fields of a given field to be + // used as a heuristic. For simplicity, the recursive sub-fields are ignored. + // And for efficiency, proto extensions as well as protos with large number of + // fields are ignored. + static int64_t EstimatedMutationWeight(const FieldDescriptor* field) { + auto descriptor = field->message_type(); + if (!descriptor) return 1; // non-message fields + static absl::flat_hash_map kFieldCountCache; + if (auto it = kFieldCountCache.find(descriptor->full_name()); + it != kFieldCountCache.end()) { + return it->second; + } + absl::flat_hash_set parents; + int result = EstimatedMutationWeight(descriptor, parents); + kFieldCountCache.insert({descriptor->full_name(), result}); + return result; + } + + template + static int EstimatedMutationWeight( + const Descriptor* descriptor, + absl::flat_hash_set& parents) { + int field_count = 0; + if (parents.contains(descriptor)) return field_count; + parents.insert(descriptor); + for (const FieldDescriptor* field : + GetProtobufFields(descriptor, /*include_extensions=*/false)) { + const auto* child = field->message_type(); + field_count += child ? EstimatedMutationWeight(child, parents) : 1; + constexpr int kMaxFieldCount = 1000; + // Avoid expensive estimation operations. + if (field_count > kMaxFieldCount) { + field_count = kMaxFieldCount; + break; + } + } + parents.erase(descriptor); + return field_count; + } + bool IsOneofRecursive(const OneofDescriptor* oneof, absl::flat_hash_set& parents, const ProtoPolicy& policy,