From ccac700789e38b1ffbde8474ce50edde4efe9c0f Mon Sep 17 00:00:00 2001 From: Hadi Ravanbakhsh Date: Mon, 3 Feb 2025 14:36:06 -0800 Subject: [PATCH] Initialize recursive protobuf fields more efficiently. Currently, a field (F0) is not initialized if there are subfields of the form F0 -> F1 -> ... -> Fs -> ... -> Fn -> Fs. This means that for a huge protobuf field that has a recursive subfield deep in its definition, the whole field is not initialized. And later, even when F0 is initialized, F1 won't get initialized, etc. This could be very inefficient. To avoid this, we define "recursion breaker fields". For example, "Fs" becomes a recursion breaker. Then, all fields up to Fs are initialized. And later when Fs gets initialized, all Fs -> ... -> Fn get initialized. This CL consists of the following changes: - IsProtoRecursive deals with infinite recursions only. - IsFieldFinitelyRecursive is replaced with IsRecursionBreaker, which is implemented separately. - Since infinite recursion and finite recursion (through recursion breakers) are now handled by separate strategies, we need to explicitly avoid initialization and mutation of infinitely recursive fields. 
PiperOrigin-RevId: 722802824 --- .../arbitrary_domains_protobuf_test.cc | 16 ++ fuzztest/BUILD | 1 + fuzztest/CMakeLists.txt | 1 + .../internal/domains/protobuf_domain_impl.h | 231 ++++++++++++------ fuzztest/internal/test_protobuf.proto | 5 + 5 files changed, 176 insertions(+), 78 deletions(-) diff --git a/domain_tests/arbitrary_domains_protobuf_test.cc b/domain_tests/arbitrary_domains_protobuf_test.cc index 8ec6150e..8d246b2a 100644 --- a/domain_tests/arbitrary_domains_protobuf_test.cc +++ b/domain_tests/arbitrary_domains_protobuf_test.cc @@ -340,6 +340,22 @@ TEST(ProtocolBufferWithRequiredFields, ShrinkingNeverRemovesRequiredFields) { } } +TEST(ProtocolBufferWithRecursiveFields, InfiniteleyRecursiveFieldsAreNotSet) { + auto domain = Arbitrary() + .WithRepeatedFieldsAlwaysSet() + .WithProtobufFieldUnset("siblings"); + absl::BitGen bitgen; + Value val(domain, bitgen); + + ASSERT_TRUE(val.user_value.IsInitialized()) << val.user_value; + + for (int i = 0; i < 1000; ++i) { + val.Mutate(domain, bitgen, {}, false); + ASSERT_TRUE(val.user_value.IsInitialized()) << val.user_value; + ASSERT_FALSE(val.user_value.has_siblings()) << val.user_value; + } +} + TEST(ProtocolBuffer, CanUsePerFieldDomains) { Domain domain = Arbitrary() diff --git a/fuzztest/BUILD b/fuzztest/BUILD index 45298b00..7e3835a9 100644 --- a/fuzztest/BUILD +++ b/fuzztest/BUILD @@ -354,6 +354,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", ], diff --git a/fuzztest/CMakeLists.txt b/fuzztest/CMakeLists.txt index d84ed8f6..1c3c9254 100644 --- a/fuzztest/CMakeLists.txt +++ b/fuzztest/CMakeLists.txt @@ -305,6 +305,7 @@ fuzztest_cc_library( absl::status absl::statusor absl::strings + absl::str_format absl::synchronization absl::span ) diff --git a/fuzztest/internal/domains/protobuf_domain_impl.h 
b/fuzztest/internal/domains/protobuf_domain_impl.h index dc00eb80..342d04c6 100644 --- a/fuzztest/internal/domains/protobuf_domain_impl.h +++ b/fuzztest/internal/domains/protobuf_domain_impl.h @@ -35,6 +35,7 @@ #include "absl/random/bit_gen_ref.h" #include "absl/random/random.h" #include "absl/status/status.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" @@ -233,7 +234,10 @@ class ProtoPolicy { public: ProtoPolicy() : optional_policies_({{/*filter=*/IncludeAll(), - /*value=*/OptionalPolicy::kWithNull}}) {} + /*value=*/OptionalPolicy::kWithNull}}) { + static int64_t next_id = 0; + id_ = next_id++; + } void SetOptionalPolicy(OptionalPolicy optional_policy) { SetOptionalPolicy(IncludeAll(), optional_policy); @@ -314,7 +318,11 @@ class ProtoPolicy { return max; } + int64_t id() const { return id_; } + private: + int64_t id_; + template struct FilterToValue { Filter filter; @@ -459,10 +467,11 @@ class ProtobufDomainUntypedImpl corpus_type Init(absl::BitGenRef prng) { if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed; - FUZZTEST_INTERNAL_CHECK( - !IsCustomizedRecursivelyOnly() || !IsInfinitelyRecursive(), - "Cannot set recursive fields by default."); const auto* descriptor = prototype_.Get()->GetDescriptor(); + FUZZTEST_INTERNAL_CHECK( + !IsCustomizedRecursivelyOnly() || !IsInfinitelyRecursive(descriptor), + absl::StrCat("Cannot set recursive fields for ", + descriptor->full_name(), " by default.")); corpus_type val; absl::flat_hash_map oneof_to_field; @@ -474,15 +483,18 @@ class ProtobufDomainUntypedImpl SelectAFieldIndexInOneof(oneof, prng); } if (oneof_to_field[oneof->index()] != field->index()) continue; - } else if (!MustBeSet(field) && IsCustomizedRecursivelyOnly() && - IsFieldFinitelyRecursive(field)) { - // We avoid initializing non-required recursive fields by default (if - // they are not explicitly customized). 
Otherwise, the initialization - // may never terminate. If a proto has only non-required recursive - // fields, the initialization will be deterministic, which violates the - // assumption on domain Init. However, such cases should be extremely - // rare and breaking the assumption would not have severe consequences. - continue; + } else if (IsCustomizedRecursivelyOnly()) { + if (!MustBeSet(field) && IsRecursionBreaker(field)) { + // We avoid initializing non-required recursive fields by default (if + // they are not explicitly customized). Otherwise, the initialization + // may never terminate. If a proto has only non-required recursive + // fields, the initialization will be deterministic, which violates + // the assumption on domain Init. However, such cases should be + // extremely rare and breaking the assumption would not have severe + // consequences. + continue; + } + if (MustBeUnset(field)) continue; } VisitProtobufField(field, InitializeVisitor{prng, *this, val}); } @@ -602,6 +614,7 @@ class ProtobufDomainUntypedImpl GetOneofFieldPolicy(field) == OptionalPolicy::kAlwaysNull) { continue; } + if (IsCustomizedRecursivelyOnly() && MustBeUnset(field)) continue; ++total_weight; if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { @@ -639,6 +652,7 @@ class ProtobufDomainUntypedImpl GetOneofFieldPolicy(field) == OptionalPolicy::kAlwaysNull) { continue; } + if (IsCustomizedRecursivelyOnly() && MustBeUnset(field)) continue; ++field_counter; if (field_counter == selected_field_index) { VisitProtobufField( @@ -951,9 +965,9 @@ class ProtobufDomainUntypedImpl for (int i = 0; i < oneof->field_count(); ++i) { OptionalPolicy policy = GetOneofFieldPolicy(oneof->field(i)); if (policy == OptionalPolicy::kAlwaysNull) continue; - if (IsCustomizedRecursivelyOnly() && - IsFieldFinitelyRecursive(oneof->field(i))) { - continue; + if (IsCustomizedRecursivelyOnly()) { + if (IsRecursionBreaker(oneof->field(i))) continue; + if (MustBeUnset(oneof->field(i))) continue; } 
fields.push_back(i); } @@ -1690,38 +1704,33 @@ class ProtobufDomainUntypedImpl return GetDomainForField(field, /*use_policy=*/false); } - // Analysis type for protobuf recursions. - enum class RecursionType { - // The proto contains a proto of type P, that must contain another P. - kInfinitelyRecursive, - // The proto contains a proto of type P, that can contain another P. - kFinitelyRecursive, - }; - - bool IsInfinitelyRecursive() { - absl::flat_hash_setGetDescriptor())> parents; - return IsProtoRecursive(prototype_.Get()->GetDescriptor(), parents, - RecursionType::kInfinitelyRecursive); + // Returns true if there are subprotos in the `descriptor` that form an + // infinite recursion. + bool IsInfinitelyRecursive(const Descriptor* descriptor) const { + FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error."); + absl::flat_hash_set parents; + return IsProtoRecursive(/*field=*/nullptr, parents, descriptor); } - bool IsFieldFinitelyRecursive(const FieldDescriptor* field) { - if (!field->message_type()) return false; + // Returns true if there are subfields in the `field` that form an + // infinite recursion of the form: F0 -> F1 -> ... -> Fs -> ... -> Fn -> Fs, + // because all Fi-s have to be set (e.g., Fi is a required field, or is + // customized using `WithFieldsAlwaysSet`). 
+ bool IsInfinitelyRecursive(const FieldDescriptor* field) const { + FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error."); ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit); - static absl::NoDestructor> + static absl::NoDestructor< + absl::flat_hash_map, bool>> cache ABSL_GUARDED_BY(mutex); - bool can_use_cache = IsCustomizedRecursivelyOnly(); - if (can_use_cache) { + { absl::MutexLock l(&mutex); - auto it = cache->find(field); + auto it = cache->find({policy_.id(), field}); if (it != cache->end()) return it->second; } - absl::flat_hash_setmessage_type())> parents; - bool result = IsProtoRecursive(field->message_type(), parents, - RecursionType::kFinitelyRecursive); - if (can_use_cache) { - absl::MutexLock l(&mutex); - cache->insert({field, result}); - } + absl::flat_hash_set parents; + bool result = IsProtoRecursive(field, parents); + absl::MutexLock l(&mutex); + cache->insert({{policy_.id(), field}, result}); return result; } @@ -1733,32 +1742,24 @@ class ProtobufDomainUntypedImpl return index == kFieldCountIndex; } - bool IsOneofRecursive(const OneofDescriptor* oneof, - absl::flat_hash_set& parents, - RecursionType recursion_type) const { + bool IsOneofRecursive( + const OneofDescriptor* oneof, + absl::flat_hash_set& parents) const { bool is_oneof_recursive = false; for (int i = 0; i < oneof->field_count(); ++i) { const auto* field = oneof->field(i); const auto field_policy = policy_.GetOptionalPolicy(field); if (field_policy == OptionalPolicy::kAlwaysNull) continue; - const auto* child = field->message_type(); - if (recursion_type == RecursionType::kInfinitelyRecursive) { - is_oneof_recursive = field_policy != OptionalPolicy::kWithNull && - child && - IsProtoRecursive(child, parents, recursion_type); - if (!is_oneof_recursive) { - return false; - } - } else { - if (child && IsProtoRecursive(child, parents, recursion_type)) { - return true; - } - } + is_oneof_recursive = field_policy != OptionalPolicy::kWithNull && + 
field->message_type() && + IsProtoRecursive(field, parents); + if (!is_oneof_recursive) return false; } return is_oneof_recursive; } bool MustBeSet(const FieldDescriptor* field) const { + FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error."); if (IsRequired(field)) { return true; } else if (field->containing_oneof()) { @@ -1775,6 +1776,14 @@ class ProtobufDomainUntypedImpl } bool MustBeUnset(const FieldDescriptor* field) const { + FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error."); + if (field->message_type() && IsInfinitelyRecursive(field)) { + absl::FPrintF( + GetStderr(), + "[!] Infinite recursion detected for %s and it remains unset.\n", + field->full_name()); + return true; + } if (IsRequired(field)) { return false; } else if (field->containing_oneof()) { @@ -1790,39 +1799,105 @@ class ProtobufDomainUntypedImpl return false; } - template - bool IsProtoRecursive(const Descriptor* descriptor, - absl::flat_hash_set& parents, - RecursionType recursion_type) const { - if (parents.contains(descriptor)) return true; - parents.insert(descriptor); + // If `field` is nullptr, all fields of `descriptor` are checked. 
+ bool IsProtoRecursive(const FieldDescriptor* field, + absl::flat_hash_set& parents, + const Descriptor* descriptor = nullptr) const { + if (field != nullptr) { + if (parents.contains(field)) return true; + parents.insert(field); + descriptor = field->message_type(); + } else { + FUZZTEST_INTERNAL_CHECK(descriptor, + "one of field or descriptor must be non-null!"); + } for (int i = 0; i < descriptor->oneof_decl_count(); ++i) { const auto* oneof = descriptor->oneof_decl(i); - if (IsOneofRecursive(oneof, parents, recursion_type)) { - parents.erase(descriptor); + if (IsOneofRecursive(oneof, parents)) { + if (field != nullptr) parents.erase(field); return true; } } - for (const FieldDescriptor* field : GetProtobufFields(descriptor)) { - if (field->containing_oneof()) continue; - const auto* child = field->message_type(); - if (!child) continue; - if (policy_.GetDefaultDomainForProtobufs(field) != std::nullopt) { + for (const FieldDescriptor* subfield : GetProtobufFields(descriptor)) { + if (subfield->containing_oneof()) continue; + if (!subfield->message_type()) continue; + if (auto default_domain = policy_.GetDefaultDomainForProtobufs(subfield); + default_domain != std::nullopt) { // For handling WithProtobufFields. // If this field is recursive, it will be detected when initializing // its default domain. Otherwise, this field can always be set safely. 
+ absl::BitGen prng; + default_domain->Init(prng); continue; } - if (recursion_type == RecursionType::kInfinitelyRecursive) { - if (!MustBeSet(field)) continue; - } else { - if (MustBeUnset(field)) continue; - } - if (IsProtoRecursive(child, parents, recursion_type)) { - parents.erase(descriptor); + if (!MustBeSet(subfield)) continue; + if (IsProtoRecursive(subfield, parents)) { + if (field != nullptr) parents.erase(field); return true; } } - parents.erase(descriptor); + if (field != nullptr) parents.erase(field); + return false; + } + + // A subset of proto types are considered as recursion breakers and during + // domain initialization, won't get recursively initialized to avoid + // non-terminating initialization. + // + // Returns true if the `field` (F0) does not have to be set, and there are + // subfields in the form: F0 -> F1 -> ... -> Fn -> F0 or F20 -> F19 ... -> F0 + // and none of other Fi-s are marked as recursion breakers so far. In other + // words, this method computes recursion breakers and check membership of + // `field` in the set of recursion breakers. + bool IsRecursionBreaker(const FieldDescriptor* field) { + FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error."); + if (!field->message_type()) return false; + absl::flat_hash_set parents; + return IsRecursionBreaker(field, field, parents); + } + + bool IsRecursionBreaker( + const FieldDescriptor* root, const FieldDescriptor* field, + absl::flat_hash_set& parents) const { + ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit); + static absl::NoDestructor< + absl::flat_hash_map, bool>> + cache ABSL_GUARDED_BY(mutex); + { + absl::MutexLock l(&mutex); + auto it = cache->find({policy_.id(), field}); + if (it != cache->end()) return it->second; + } + // Cannot break the recursion for required fields. 
+ bool can_be_unset = !MustBeSet(field); + if (field->containing_oneof() && !can_be_unset) { // oneof must be set + // We check whether `field` is infinitely recursive without considering + // other oneof fields. If it is, there's another field in the oneof that + // can be set. + absl::flat_hash_set subfield_parents; + subfield_parents.insert(field); + can_be_unset = IsProtoRecursive(field, subfield_parents); + } + if (can_be_unset) { + // Break recursion for deeply nested or recursive protos. + if (parents.size() > 20 || parents.contains(field)) { + absl::MutexLock l(&mutex); + cache->insert({{policy_.id(), field}, true}); + return true; + } + parents.insert(field); + } + for (const FieldDescriptor* subfield : + GetProtobufFields(field->message_type())) { + if (!subfield->message_type()) continue; + if (MustBeUnset(subfield)) continue; + IsRecursionBreaker(root, subfield, parents); + } + if (can_be_unset) parents.erase(field); + absl::MutexLock l(&mutex); + cache->insert({{policy_.id(), field}, false}); + auto it = cache->find({policy_.id(), field}); + if (it != cache->end()) return it->second; + cache->insert({{policy_.id(), field}, false}); return false; } diff --git a/fuzztest/internal/test_protobuf.proto b/fuzztest/internal/test_protobuf.proto index dcf03aea..b229af30 100644 --- a/fuzztest/internal/test_protobuf.proto +++ b/fuzztest/internal/test_protobuf.proto @@ -121,6 +121,10 @@ message RecursiveExtender { } } +message TestProtobufWithRepeatedRecursion { + repeated TestProtobufWithRepeatedRecursion siblings = 1; +} + message TestProtobufWithRecursion { message ChildProto { optional TestProtobufWithRecursion parent1 = 1; @@ -134,6 +138,7 @@ message TestProtobufWithRecursion { int32 child_id = 3; } optional TestProtobufWithExtension ext = 4; + optional TestProtobufWithRepeatedRecursion siblings = 5; } message MessageWithGroup {